llama_core/chat/responses.rs
1//! Define APIs for chat completion.
2#[allow(unused_imports)]
3use crate::{
4 error,
5 metadata::ggml::GgmlMetadata,
6 running_mode,
7 utils::{
8 gen_chat_id, gen_response_id, get_output_buffer, get_output_buffer_single,
9 get_token_info_by_graph, get_token_info_by_graph_name, set_tensor_data_u8,
10 },
11 Graph, RunningMode, CACHED_UTF8_ENCODINGS, CHAT_GRAPHS, OUTPUT_TENSOR,
12};
13use chat_prompts::{BuildChatPrompt, ChatPrompt, PromptTemplateType};
14use either::{Either, Left, Right};
15use endpoints::{
16 // chat::{
17 // ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
18 // ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
19 // ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
20 // ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
21 // ToolChoice,
22 // },
23 chat::{ChatCompletionRequestMessage, ChatCompletionRole, ChatCompletionUserMessageContent},
24 responses::{
25 items::{ResponseOutputItem, ResponseOutputItemOutputMessageContent},
26 response_object::{
27 Input, InputItem, InputMessageContent, InputTokensDetails, OutputTokensDetails,
28 RequestOfModelResponse, ResponseObject, ToolChoice, Usage,
29 },
30 },
31 // common::{FinishReason, Usage},
32};
33use error::{BackendError, LlamaCoreError};
34use std::{
35 collections::{HashMap, VecDeque},
36 pin::Pin,
37 sync::{
38 atomic::{AtomicBool, Ordering},
39 Mutex, OnceLock,
40 },
41 task::{Context, Poll, Waker},
42 time::SystemTime,
43};
44
// Global FIFO of `Waker`s for tasks whose `ChatStream` is waiting its turn.
// Lazily initialized on first use; the `Mutex` guards concurrent push/pop.
// NOTE(review): not referenced in this chunk — presumably consumed by the
// `ChatStream` implementation elsewhere in this file; confirm before relying
// on the queue semantics.
static CHAT_STREAM_WAKER_QUEUE: OnceLock<Mutex<VecDeque<Waker>>> = OnceLock::new();

// Flag indicating whether a `ChatStream` is currently active, used to
// serialize streams (only one may run at a time, judging by the waker queue
// above). Also not referenced in this chunk — TODO confirm against the
// `ChatStream` impl.
static CHAT_STREAM_ACTIVE: AtomicBool = AtomicBool::new(false);
50
51/// Processes a chat-completion request and returns either a stream of ChatCompletionChunk instances or a ChatCompletionObject instance.
52pub async fn chat(
53 chat_request: &mut RequestOfModelResponse,
54) -> Result<
55 (
56 Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ResponseObject>,
57 bool,
58 ),
59 LlamaCoreError,
60> {
61 #[cfg(feature = "logging")]
62 {
63 debug!(target: "stdout", "tool choice: {:?}", &chat_request.tool_choice);
64 debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
65 debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
66 }
67
68 let result = match chat_request.stream {
69 Some(true) => match chat_stream(chat_request).await {
70 Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
71 Err(e) => Err(e),
72 },
73 Some(false) | None => match chat_once(chat_request).await {
74 Ok((chat_completion_object, include_tool_calls)) => {
75 Ok((Right(chat_completion_object), include_tool_calls))
76 }
77 Err(e) => Err(e),
78 },
79 };
80
81 #[cfg(feature = "logging")]
82 info!(target: "stdout", "Reset the model metadata");
83
84 result
85}
86
/// Processes a chat-completion request in stream mode.
///
/// Prepares the model (metadata, prompt, `n_predict` budget), feeds the
/// prompt, and returns a `ChatStream` that yields `String` chunks as the
/// model generates tokens, plus a flag telling whether tool calls are
/// included in the stream (currently always `false`; the tool-use streaming
/// path is not implemented — see the `todo!` below).
///
/// # Errors
///
/// Returns [`LlamaCoreError::Operation`] when the core is not running in
/// CHAT or RAG mode, and propagates errors from metadata/prompt preparation.
async fn chat_stream(
    chat_request: &mut RequestOfModelResponse,
) -> Result<
    (
        impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
        bool,
    ),
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Process chat completion request in the stream mode");

    // Chat completion is only available when the core runs in CHAT or RAG mode.
    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
        let err_msg = "The chat completion is only supported in the chat or rag mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let model_name = chat_request.model.clone();

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Check model metadata");

    // update metadata (syncs per-request options into the cached model metadata)
    let mut metadata = check_model_metadata(chat_request)?;

    // parse the `include_usage` option
    // let include_usage = match chat_request.stream_options {
    //     Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
    //     None => metadata.include_usage,
    // };
    // #[cfg(feature = "logging")]
    // info!(target: "stdout", "include_usage: {include_usage}");

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Build the chat prompt");

    // build prompt; also yields the remaining completion-token budget and
    // whether the request involves tool use
    let (prompt, avaible_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;

    #[cfg(feature = "logging")]
    {
        info!(target: "stdout", "prompt:\n{}", &prompt);
        info!(target: "stdout", "available_completion_tokens: {avaible_completion_tokens}");
        info!(target: "stdout", "tool_use: {tool_use}");
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Update the n_predict");

    // update metadata n_predict (clamp to the remaining completion budget)
    update_n_predict(chat_request, &mut metadata, avaible_completion_tokens)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the prompt to the model");

    // set prompt
    set_prompt(chat_request.model.as_ref(), &prompt)?;

    // The returned value is `(stream, include_tool_calls)`.
    let stream = match tool_use {
        false => (ChatStream::new(model_name, None), false),
        true => {
            // Tool-use streaming is not implemented yet; the commented-out
            // scaffolding below is the intended graph-selection logic.
            todo!("Implement the streaming with tool use")

            // let chat_graphs = match CHAT_GRAPHS.get() {
            //     Some(chat_graphs) => chat_graphs,
            //     None => {
            //         let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

            //         #[cfg(feature = "logging")]
            //         error!(target: "stdout", "{}", &err_msg);

            //         return Err(LlamaCoreError::Operation(err_msg.into()));
            //     }
            // };

            // let mut chat_graphs = chat_graphs.lock().map_err(|e| {
            //     let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");

            //     #[cfg(feature = "logging")]
            //     error!(target: "stdout", "{}", &err_msg);

            //     LlamaCoreError::Operation(err_msg)
            // })?;

            // match model_name {
            //     Some(model_name) => match chat_graphs.contains_key(&model_name) {
            //         true => {
            //             let graph = chat_graphs.get_mut(&model_name).unwrap();
            //             chat_stream_for_tool(graph, id, include_usage)?
            //         }
            //         false => match chat_graphs.iter_mut().next() {
            //             Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
            //             None => {
            //                 let err_msg = "There is no model available in the chat graphs.";

            //                 #[cfg(feature = "logging")]
            //                 error!(target: "stdout", "{}", &err_msg);

            //                 return Err(LlamaCoreError::Operation(err_msg.into()));
            //             }
            //         },
            //     },
            //     None => match chat_graphs.iter_mut().next() {
            //         Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
            //         None => {
            //             let err_msg = "There is no model available in the chat graphs.";

            //             #[cfg(feature = "logging")]
            //             error!(target: "stdout", "{}", &err_msg);

            //             return Err(LlamaCoreError::Operation(err_msg.into()));
            //         }
            //     },
            // }
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the chat completion stream.");

    Ok(stream)
}
215
216async fn chat_once(
217 chat_request: &mut RequestOfModelResponse,
218) -> Result<(ResponseObject, bool), LlamaCoreError> {
219 #[cfg(feature = "logging")]
220 info!(target: "stdout", "Processing chat completion request in non-stream mode");
221
222 let running_mode = running_mode()?;
223 if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
224 let err_msg = "The chat completion is only supported in the chat or rag mode.";
225
226 #[cfg(feature = "logging")]
227 error!(target: "stdout", "{err_msg}");
228
229 return Err(LlamaCoreError::Operation(err_msg.to_string()));
230 }
231
232 let model_name = chat_request.model.clone();
233
234 // let id = match &chat_request.user {
235 // Some(id) => id.clone(),
236 // None => gen_chat_id(),
237 // };
238
239 // #[cfg(feature = "logging")]
240 // info!(target: "stdout", "user: {}", &id);
241
242 #[cfg(feature = "logging")]
243 info!(target: "stdout", "Check model metadata");
244
245 // update metadata
246 let mut metadata = check_model_metadata(chat_request)?;
247
248 #[cfg(feature = "logging")]
249 info!(target: "stdout", "Build the chat prompt");
250
251 // build prompt
252 let (prompt, avaible_completion_tokens, tool_use) =
253 build_prompt(model_name.as_ref(), chat_request)?;
254
255 #[cfg(feature = "logging")]
256 {
257 info!(target: "stdout", "prompt:\n{}", &prompt);
258 info!(target: "stdout", "available_completion_tokens: {avaible_completion_tokens}");
259 info!(target: "stdout", "tool_use: {tool_use}");
260 }
261
262 #[cfg(feature = "logging")]
263 info!(target: "stdout", "Update n_predict");
264
265 // update metadata n_predict
266 update_n_predict(chat_request, &mut metadata, avaible_completion_tokens)?;
267
268 #[cfg(feature = "logging")]
269 info!(target: "stdout", "Feed the prompt to the model");
270
271 // feed the prompt to the model
272 set_prompt(model_name.as_ref(), &prompt)?;
273
274 #[cfg(feature = "logging")]
275 info!(target: "stdout", "Compute chat completion.");
276
277 // compute
278 let res = compute(chat_request, model_name.as_ref(), tool_use);
279
280 #[cfg(feature = "logging")]
281 info!(target: "stdout", "End of the chat completion");
282
283 // reset the model metadata
284 reset_model_metadata(model_name.as_ref())?;
285
286 res
287}
288
289fn compute(
290 chat_request: &mut RequestOfModelResponse,
291 model_name: Option<&String>,
292 tool_use: bool,
293) -> Result<(ResponseObject, bool), LlamaCoreError> {
294 let chat_graphs = match CHAT_GRAPHS.get() {
295 Some(chat_graphs) => chat_graphs,
296 None => {
297 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
298
299 #[cfg(feature = "logging")]
300 error!(target: "stdout", "{}", &err_msg);
301
302 return Err(LlamaCoreError::Operation(err_msg.into()));
303 }
304 };
305
306 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
307 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
308
309 #[cfg(feature = "logging")]
310 error!(target: "stdout", "{}", &err_msg);
311
312 LlamaCoreError::Operation(err_msg)
313 })?;
314
315 match model_name {
316 Some(model_name) => match chat_graphs.contains_key(model_name) {
317 true => {
318 let graph = chat_graphs.get_mut(model_name).unwrap();
319 compute_by_graph(chat_request, graph, tool_use)
320 }
321 false => match chat_graphs.iter_mut().next() {
322 Some((_, graph)) => compute_by_graph(chat_request, graph, tool_use),
323 None => {
324 let err_msg = "There is no model available in the chat graphs.";
325
326 #[cfg(feature = "logging")]
327 error!(target: "stdout", "{}", &err_msg);
328
329 Err(LlamaCoreError::Operation(err_msg.into()))
330 }
331 },
332 },
333 None => match chat_graphs.iter_mut().next() {
334 Some((_, graph)) => compute_by_graph(chat_request, graph, tool_use),
335 None => {
336 let err_msg = "There is no model available in the chat graphs.";
337
338 #[cfg(feature = "logging")]
339 error!(target: "stdout", "{}", &err_msg);
340
341 Err(LlamaCoreError::Operation(err_msg.into()))
342 }
343 },
344 }
345}
346
347fn compute_by_graph(
348 chat_request: &mut RequestOfModelResponse,
349 graph: &mut Graph<GgmlMetadata>,
350 tool_use: bool,
351) -> Result<(ResponseObject, bool), LlamaCoreError> {
352 #[cfg(feature = "logging")]
353 info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());
354
355 match graph.compute() {
356 Ok(_) => {
357 // Retrieve the output.
358 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
359 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
360 let err_msg = format!(
361 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
362 );
363
364 #[cfg(feature = "logging")]
365 error!(target: "stdout", "{}", &err_msg);
366
367 LlamaCoreError::Operation(err_msg)
368 })?;
369
370 #[cfg(feature = "logging")]
371 info!(target: "stdout", "raw generation: {output}");
372
373 // post-process
374 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
375 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
376 })?;
377
378 #[cfg(feature = "logging")]
379 info!(target: "stdout", "post-processed generation:\n{}", &message);
380
381 // retrieve the number of prompt and completion tokens
382 let token_info = get_token_info_by_graph(graph)?;
383
384 #[cfg(feature = "logging")]
385 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
386
387 let created = SystemTime::now()
388 .duration_since(std::time::UNIX_EPOCH)
389 .map_err(|e| {
390 let err_msg = format!("Failed to get the current time. Reason: {e}");
391
392 #[cfg(feature = "logging")]
393 error!(target: "stdout", "{}", &err_msg);
394
395 LlamaCoreError::Operation(err_msg)
396 })?;
397
398 match tool_use {
399 true => {
400 if graph.metadata.prompt_template != PromptTemplateType::MistralTool
401 && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
402 && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
403 && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
404 && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
405 && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
406 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
407 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
408 && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
409 && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
410 && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
411 && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
412 && graph.metadata.prompt_template != PromptTemplateType::Gemma3
413 && graph.metadata.prompt_template != PromptTemplateType::GptOss
414 && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
415 && graph.metadata.prompt_template != PromptTemplateType::SeedOssNoThink
416 && graph.metadata.prompt_template != PromptTemplateType::SeedOssThink
417 {
418 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', 'qwen3-agent', 'seed-oss-no-think', and 'seed-oss-think' prompt templates.", graph.metadata.prompt_template);
419
420 #[cfg(feature = "logging")]
421 error!(target: "stdout", "{}", &err_msg);
422
423 return Err(LlamaCoreError::Operation(err_msg));
424 }
425
426 // let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
427
428 // let (finish_reason, content, include_tool_calls) =
429 // if parsed_result.tool_calls.is_empty() {
430 // (FinishReason::stop, Some(parsed_result.raw.clone()), false)
431 // } else if graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent {
432 // (
433 // FinishReason::tool_calls,
434 // Some(parsed_result.raw.clone()),
435 // true,
436 // )
437 // } else {
438 // (
439 // FinishReason::tool_calls,
440 // parsed_result.content.clone(),
441 // true,
442 // )
443 // };
444
445 // let res = ChatCompletionObject {
446 // id: id.into(),
447 // object: String::from("chat.completion"),
448 // created: created.as_secs(),
449 // model: graph.name().to_owned(),
450 // choices: vec![ChatCompletionObjectChoice {
451 // index: 0,
452 // message: ChatCompletionObjectMessage {
453 // role: ChatCompletionRole::Assistant,
454 // content,
455 // tool_calls: parsed_result.tool_calls,
456 // function_call: None,
457 // },
458 // finish_reason,
459 // logprobs: None,
460 // }],
461 // usage: Usage {
462 // prompt_tokens: token_info.prompt_tokens,
463 // completion_tokens: token_info.completion_tokens,
464 // total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
465 // },
466 // };
467
468 // // create ChatCompletionResponse
469 // Ok((res, include_tool_calls))
470
471 unimplemented!("Return ResponseObject with tool calls")
472 }
473 false => {
474 let message = ResponseOutputItemOutputMessageContent::OutputText {
475 annotations: vec![],
476 text: message.to_string(),
477 ty: "output_text".to_string(),
478 logprobs: None,
479 };
480
481 let output_message = ResponseOutputItem::OutputMessage {
482 content: vec![message],
483 id: "msg_67ccd3acc8d48190a77525dc6de64b4104becb25c45c1d41".to_string(),
484 role: ChatCompletionRole::Assistant.to_string(),
485 status: "completed".to_string(),
486 ty: "message".to_string(),
487 };
488
489 let temperature = match &chat_request.temperature {
490 Some(t) => *t,
491 None => graph.metadata.temperature,
492 };
493
494 let top_p = match &chat_request.top_p {
495 Some(t) => *t,
496 None => graph.metadata.top_p,
497 };
498
499 let res = ResponseObject {
500 background: false,
501 conversation: None,
502 created_at: created.as_secs(),
503 error: None,
504 id: gen_response_id(),
505 incomplete_details: None,
506 instructions: None,
507 max_output_tokens: None,
508 max_tool_calls: None,
509 metadata: HashMap::new(),
510 model: graph.name().to_owned(),
511 object: "response".to_string(),
512 output: vec![output_message],
513 parallel_tool_calls: true,
514 previous_response_id: None,
515 safety_identifier: None,
516 status: "completed".to_string(),
517 temperature,
518 tool_choice: chat_request.tool_choice.clone(),
519 tools: chat_request.tools.clone(),
520 top_p,
521 truncation: None,
522 usage: Usage {
523 input_tokens: token_info.prompt_tokens,
524 input_tokens_details: InputTokensDetails { cached_tokens: 0 },
525 output_tokens: token_info.completion_tokens,
526 output_tokens_details: OutputTokensDetails {
527 reasoning_tokens: 0,
528 },
529 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
530 },
531 };
532
533 Ok((res, false))
534 }
535 }
536 }
537 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
538 // Retrieve the output.
539 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
540 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
541 let err_msg = format!(
542 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
543 );
544
545 #[cfg(feature = "logging")]
546 error!(target: "stdout", "{}", &err_msg);
547
548 LlamaCoreError::Operation(err_msg)
549 })?;
550
551 // post-process
552 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
553 let err_msg = format!("Failed to post-process the output. {e}");
554
555 #[cfg(feature = "logging")]
556 error!(target: "stdout", "{}", &err_msg);
557
558 LlamaCoreError::Operation(err_msg)
559 })?;
560
561 // retrieve the number of prompt and completion tokens
562 let token_info = get_token_info_by_graph(graph)?;
563
564 #[cfg(feature = "logging")]
565 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
566
567 let created = SystemTime::now()
568 .duration_since(std::time::UNIX_EPOCH)
569 .map_err(|e| {
570 let err_msg = format!("Failed to get the current time. Reason: {e}");
571
572 #[cfg(feature = "logging")]
573 error!(target: "stdout", "{}", &err_msg);
574
575 LlamaCoreError::Operation(err_msg)
576 })?;
577
578 let message = ResponseOutputItemOutputMessageContent::OutputText {
579 annotations: vec![],
580 text: message.to_string(),
581 ty: "output_text".to_string(),
582 logprobs: None,
583 };
584
585 let output_message = ResponseOutputItem::OutputMessage {
586 content: vec![message],
587 id: "msg_67ccd3acc8d48190a77525dc6de64b4104becb25c45c1d41".to_string(),
588 role: ChatCompletionRole::Assistant.to_string(),
589 status: "completed".to_string(),
590 ty: "message".to_string(),
591 };
592
593 let temperature = match &chat_request.temperature {
594 Some(t) => *t,
595 None => graph.metadata.temperature,
596 };
597
598 let top_p = match &chat_request.top_p {
599 Some(t) => *t,
600 None => graph.metadata.top_p,
601 };
602
603 let res = ResponseObject {
604 background: false,
605 conversation: None,
606 created_at: created.as_secs(),
607 error: None,
608 id: gen_response_id(),
609 incomplete_details: None,
610 instructions: None,
611 max_output_tokens: None,
612 max_tool_calls: None,
613 metadata: HashMap::new(),
614 model: graph.name().to_owned(),
615 object: "response".to_string(),
616 output: vec![output_message],
617 parallel_tool_calls: true,
618 previous_response_id: None,
619 safety_identifier: None,
620 status: "completed".to_string(),
621 temperature,
622 tool_choice: chat_request.tool_choice.clone(),
623 tools: chat_request.tools.clone(),
624 top_p,
625 truncation: None,
626 usage: Usage {
627 input_tokens: token_info.prompt_tokens,
628 input_tokens_details: InputTokensDetails { cached_tokens: 0 },
629 output_tokens: token_info.completion_tokens,
630 output_tokens_details: OutputTokensDetails {
631 reasoning_tokens: 0,
632 },
633 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
634 },
635 };
636
637 Ok((res, false))
638 }
639 Err(wasmedge_wasi_nn::Error::BackendError(
640 wasmedge_wasi_nn::BackendError::PromptTooLong,
641 )) => {
642 #[cfg(feature = "logging")]
643 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
644
645 // Retrieve the output.
646 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
647 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
648 let err_msg = format!(
649 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
650 );
651
652 #[cfg(feature = "logging")]
653 error!(target: "stdout", "{}", &err_msg);
654
655 LlamaCoreError::Operation(err_msg)
656 })?;
657
658 // post-process
659 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
660 let err_msg = format!("Failed to post-process the output. {e}");
661
662 #[cfg(feature = "logging")]
663 error!(target: "stdout", "{}", &err_msg);
664
665 LlamaCoreError::Operation(err_msg)
666 })?;
667
668 // retrieve the number of prompt and completion token
669 let token_info = get_token_info_by_graph(graph)?;
670
671 #[cfg(feature = "logging")]
672 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
673
674 let created = SystemTime::now()
675 .duration_since(std::time::UNIX_EPOCH)
676 .map_err(|e| {
677 let err_msg = format!("Failed to get the current time. Reason: {e}");
678
679 #[cfg(feature = "logging")]
680 error!(target: "stdout", "{}", &err_msg);
681
682 LlamaCoreError::Operation(err_msg)
683 })?;
684
685 let message = ResponseOutputItemOutputMessageContent::OutputText {
686 annotations: vec![],
687 text: message.to_string(),
688 ty: "output_text".to_string(),
689 logprobs: None,
690 };
691
692 let output_message = ResponseOutputItem::OutputMessage {
693 content: vec![message],
694 id: "msg_67ccd3acc8d48190a77525dc6de64b4104becb25c45c1d41".to_string(),
695 role: ChatCompletionRole::Assistant.to_string(),
696 status: "completed".to_string(),
697 ty: "message".to_string(),
698 };
699
700 let temperature = match &chat_request.temperature {
701 Some(t) => *t,
702 None => graph.metadata.temperature,
703 };
704
705 let top_p = match &chat_request.top_p {
706 Some(t) => *t,
707 None => graph.metadata.top_p,
708 };
709
710 let res = ResponseObject {
711 background: false,
712 conversation: None,
713 created_at: created.as_secs(),
714 error: None,
715 id: gen_response_id(),
716 incomplete_details: None,
717 instructions: None,
718 max_output_tokens: None,
719 max_tool_calls: None,
720 metadata: HashMap::new(),
721 model: graph.name().to_owned(),
722 object: "response".to_string(),
723 output: vec![output_message],
724 parallel_tool_calls: true,
725 previous_response_id: None,
726 safety_identifier: None,
727 status: "completed".to_string(),
728 temperature,
729 tool_choice: chat_request.tool_choice.clone(),
730 tools: chat_request.tools.clone(),
731 top_p,
732 truncation: None,
733 usage: Usage {
734 input_tokens: token_info.prompt_tokens,
735 input_tokens_details: InputTokensDetails { cached_tokens: 0 },
736 output_tokens: token_info.completion_tokens,
737 output_tokens_details: OutputTokensDetails {
738 reasoning_tokens: 0,
739 },
740 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
741 },
742 };
743
744 Ok((res, false))
745 }
746 Err(e) => {
747 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
748
749 #[cfg(feature = "logging")]
750 error!(target: "stdout", "{}", &err_msg);
751
752 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
753 }
754 }
755}
756
757// fn parse_tool_calls(
758// input: &str,
759// prompt_template: PromptTemplateType,
760// ) -> Result<ParseResult, LlamaCoreError> {
761// match prompt_template {
762// PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
763// Ok(re) => {
764// let mut values: Vec<serde_json::Value> = vec![];
765// for cap in re.captures_iter(input) {
766// let matched = &cap[0];
767
768// #[cfg(feature = "logging")]
769// info!(target: "stdout", "captured: {matched}");
770
771// match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
772// Ok(group) => values.extend(group),
773// Err(e) => {
774// let err_msg =
775// format!("Failed to deserialize generated tool calls. Reason: {e}");
776
777// #[cfg(feature = "logging")]
778// error!(target: "stdout", "{}", &err_msg);
779
780// return Err(LlamaCoreError::Operation(err_msg));
781// }
782// }
783// }
784
785// let mut tool_calls: Vec<ToolCall> = vec![];
786// for value in values.iter() {
787// let name = match value.get("name") {
788// Some(name) => name.to_string().replace("\"", ""),
789// None => {
790// let err_msg = format!(
791// "Failed to get the name of the function. Tool call: {value:?}"
792// );
793
794// #[cfg(feature = "logging")]
795// error!(target: "stdout", "{}", &err_msg);
796
797// return Err(LlamaCoreError::Operation(err_msg));
798// }
799// };
800
801// let arguments = match value.get("arguments") {
802// Some(arguments) => arguments.to_string(),
803// None => {
804// let err_msg = format!(
805// "Failed to get the arguments of the function. Tool call: {value:?}"
806// );
807
808// #[cfg(feature = "logging")]
809// error!(target: "stdout", "{}", &err_msg);
810
811// return Err(LlamaCoreError::Operation(err_msg));
812// }
813// };
814
815// let function = Function { name, arguments };
816
817// let tool_call = ToolCall {
818// id: "call_abc123".to_string(),
819// ty: "function".to_string(),
820// function,
821// };
822
823// tool_calls.push(tool_call);
824// }
825
826// let parsed = ParseResult {
827// raw: input.to_owned(),
828// content: None,
829// tool_calls,
830// };
831
832// #[cfg(feature = "logging")]
833// info!(target: "stdout", "parsed result: {parsed:?}");
834
835// Ok(parsed)
836// }
837// Err(e) => {
838// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
839
840// #[cfg(feature = "logging")]
841// error!(target: "stdout", "{}", &err_msg);
842
843// Err(LlamaCoreError::Operation(err_msg))
844// }
845// },
846// PromptTemplateType::ChatMLTool => {
847// match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
848// Ok(re) => {
849// let mut values: Vec<serde_json::Value> = vec![];
850// for cap in re.captures_iter(input) {
851// let matched = cap[1].replace("\\n", ""); // Remove "\\n" from the captured group
852
853// #[cfg(feature = "logging")]
854// info!(target: "stdout", "captured: {}", &matched);
855
856// match serde_json::from_str::<serde_json::Value>(&matched) {
857// Ok(value) => values.push(value),
858// Err(e) => {
859// let err_msg = format!(
860// "Failed to deserialize generated tool calls. Reason: {e}"
861// );
862
863// #[cfg(feature = "logging")]
864// error!(target: "stdout", "{}", &err_msg);
865
866// return Err(LlamaCoreError::Operation(err_msg));
867// }
868// }
869// }
870
871// let mut tool_calls: Vec<ToolCall> = vec![];
872// for value in values.iter() {
873// let name = match value.get("name") {
874// Some(name) => name.to_string().replace("\"", ""),
875// None => {
876// let err_msg = format!(
877// "Failed to get the name of the function. Tool call: {value:?}"
878// );
879
880// #[cfg(feature = "logging")]
881// error!(target: "stdout", "{}", &err_msg);
882
883// return Err(LlamaCoreError::Operation(err_msg));
884// }
885// };
886
887// let arguments = match value.get("arguments") {
888// Some(arguments) => arguments.to_string(),
889// None => {
890// let err_msg = format!(
891// "Failed to get the arguments of the function. Tool call: {value:?}"
892// );
893
894// #[cfg(feature = "logging")]
895// error!(target: "stdout", "{}", &err_msg);
896
897// return Err(LlamaCoreError::Operation(err_msg));
898// }
899// };
900
901// let function = Function { name, arguments };
902
903// let tool_call = ToolCall {
904// id: "call_abc123".to_string(),
905// ty: "function".to_string(),
906// function,
907// };
908
909// tool_calls.push(tool_call);
910// }
911
912// let parsed = ParseResult {
913// raw: input.to_owned(),
914// content: None,
915// tool_calls,
916// };
917
918// #[cfg(feature = "logging")]
919// info!(target: "stdout", "parsed result: {parsed:?}");
920
921// Ok(parsed)
922// }
923// Err(e) => {
924// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
925
926// #[cfg(feature = "logging")]
927// error!(target: "stdout", "{}", &err_msg);
928
929// Err(LlamaCoreError::Operation(err_msg))
930// }
931// }
932// }
933// PromptTemplateType::GroqLlama3Tool => {
934// #[cfg(feature = "logging")]
935// info!(target: "stdout", "raw input: {input}");
936
937// match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
938// Ok(re) => {
939// let mut values: Vec<serde_json::Value> = vec![];
940// for cap in re.captures_iter(input) {
941// let matched = cap[1].trim();
942
943// #[cfg(feature = "logging")]
944// info!(target: "stdout", "captured: {matched}");
945
946// match serde_json::from_str::<serde_json::Value>(matched) {
947// Ok(value) => values.push(value),
948// Err(e) => {
949// let err_msg = format!(
950// "Failed to deserialize generated tool calls. Reason: {e}"
951// );
952
953// #[cfg(feature = "logging")]
954// error!(target: "stdout", "{}", &err_msg);
955
956// return Err(LlamaCoreError::Operation(err_msg));
957// }
958// }
959// }
960
961// let mut tool_calls: Vec<ToolCall> = vec![];
962// for value in values.iter() {
963// let name = match value.get("name") {
964// Some(name) => name.to_string().replace("\"", ""),
965// None => {
966// let err_msg = format!(
967// "Failed to get the name of the function. Tool call: {value:?}"
968// );
969
970// #[cfg(feature = "logging")]
971// error!(target: "stdout", "{}", &err_msg);
972
973// return Err(LlamaCoreError::Operation(err_msg));
974// }
975// };
976
977// let arguments = match value.get("arguments") {
978// Some(arguments) => {
979// if arguments.is_string() {
980// arguments.as_str().unwrap().to_string()
981// } else if arguments.is_object() {
982// let map = arguments.as_object().unwrap();
983
984// #[cfg(feature = "logging")]
985// info!(target: "stdout", "func arguments: {map:?}");
986
987// serde_json::to_string(map).unwrap()
988// } else {
989// serde_json::to_string(arguments).unwrap()
990// }
991// }
992// None => {
993// let err_msg = format!(
994// "Failed to get the arguments of the function. Tool call: {value:?}"
995// );
996
997// #[cfg(feature = "logging")]
998// error!(target: "stdout", "{}", &err_msg);
999
1000// return Err(LlamaCoreError::Operation(err_msg));
1001// }
1002// };
1003
1004// let function = Function { name, arguments };
1005
1006// let tool_call = ToolCall {
1007// id: "call_abc123".to_string(),
1008// ty: "function".to_string(),
1009// function,
1010// };
1011
1012// tool_calls.push(tool_call);
1013// }
1014
1015// let parsed = if tool_calls.is_empty() {
1016// ParseResult {
1017// raw: input.to_owned(),
1018// content: Some(input.to_owned()),
1019// tool_calls: vec![],
1020// }
1021// } else {
1022// ParseResult {
1023// raw: input.to_owned(),
1024// content: None,
1025// tool_calls,
1026// }
1027// };
1028
1029// #[cfg(feature = "logging")]
1030// info!(target: "stdout", "parsed result: {parsed:?}");
1031
1032// Ok(parsed)
1033// }
1034// Err(e) => {
1035// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1036
1037// #[cfg(feature = "logging")]
1038// error!(target: "stdout", "{}", &err_msg);
1039
1040// Err(LlamaCoreError::Operation(err_msg))
1041// }
1042// }
1043// }
1044// PromptTemplateType::Llama3Tool => {
1045// #[cfg(feature = "logging")]
1046// info!(target: "stdout", "raw input: {input}");
1047
1048// let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1049// Ok(re) => re,
1050// Err(e) => {
1051// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1052
1053// #[cfg(feature = "logging")]
1054// error!(target: "stdout", "{}", &err_msg);
1055
1056// return Err(LlamaCoreError::Operation(err_msg));
1057// }
1058// };
1059
1060// if re.is_match(input) {
1061// match serde_json::from_str::<serde_json::Value>(input) {
1062// Ok(value) => {
1063// let values: Vec<serde_json::Value> = vec![value];
1064
1065// let mut tool_calls: Vec<ToolCall> = vec![];
1066// for value in values.iter() {
1067// let name = match value.get("name") {
1068// Some(name) => name.to_string().replace("\"", ""),
1069// None => {
1070// let err_msg = format!(
1071// "Failed to get the name of the function. Tool call: {value:?}"
1072// );
1073
1074// #[cfg(feature = "logging")]
1075// error!(target: "stdout", "{}", &err_msg);
1076
1077// return Err(LlamaCoreError::Operation(err_msg));
1078// }
1079// };
1080
1081// let arguments = match value.get("parameters") {
1082// Some(arguments) => arguments.to_string(),
1083// None => {
1084// let err_msg = format!(
1085// "Failed to get the arguments of the function. Tool call: {value:?}"
1086// );
1087
1088// #[cfg(feature = "logging")]
1089// error!(target: "stdout", "{}", &err_msg);
1090
1091// return Err(LlamaCoreError::Operation(err_msg));
1092// }
1093// };
1094
1095// let function = Function { name, arguments };
1096
1097// let tool_call = ToolCall {
1098// id: "call_abc123".to_string(),
1099// ty: "function".to_string(),
1100// function,
1101// };
1102
1103// tool_calls.push(tool_call);
1104// }
1105
1106// let parsed = ParseResult {
1107// raw: input.to_owned(),
1108// content: None,
1109// tool_calls,
1110// };
1111
1112// #[cfg(feature = "logging")]
1113// info!(target: "stdout", "parsed result: {parsed:?}");
1114
1115// Ok(parsed)
1116// }
1117// Err(e) => {
1118// let err_msg =
1119// format!("Failed to deserialize generated tool calls. Reason: {e}");
1120
1121// #[cfg(feature = "logging")]
1122// error!(target: "stdout", "{}", &err_msg);
1123
1124// Err(LlamaCoreError::Operation(err_msg))
1125// }
1126// }
1127// } else {
1128// let parsed = ParseResult {
1129// raw: input.to_owned(),
1130// content: None,
1131// tool_calls: vec![],
1132// };
1133
1134// #[cfg(feature = "logging")]
1135// info!(target: "stdout", "parsed result: {parsed:?}");
1136
1137// Ok(parsed)
1138// }
1139// }
1140// PromptTemplateType::InternLM2Tool => {
1141// #[cfg(feature = "logging")]
1142// info!(target: "stdout", "raw input: {input}");
1143
1144// let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1145
1146// #[cfg(feature = "logging")]
1147// info!(target: "stdout", "blocks: {blocks:?}");
1148
1149// let mut tool_calls: Vec<ToolCall> = vec![];
1150// let mut content = String::new();
1151// for block in blocks {
1152// let block = block.trim();
1153// if !block.is_empty() {
1154// if block.ends_with("<|action_end|>") {
1155// let value = block.trim().trim_end_matches("<|action_end|>");
1156
1157// #[cfg(feature = "logging")]
1158// info!(target: "stdout", "tool call: {value}");
1159
1160// match serde_json::from_str::<serde_json::Value>(value) {
1161// Ok(value) => {
1162// let name = match value.get("name") {
1163// Some(name) => name.to_string().replace("\"", ""),
1164// None => {
1165// let err_msg = format!(
1166// "Failed to get the name of the function. Tool call: {value:?}"
1167// );
1168
1169// #[cfg(feature = "logging")]
1170// error!(target: "stdout", "{}", &err_msg);
1171
1172// return Err(LlamaCoreError::Operation(err_msg));
1173// }
1174// };
1175
1176// let arguments = match value.get("parameters") {
1177// Some(arguments) => arguments.to_string(),
1178// None => {
1179// let err_msg = format!(
1180// "Failed to get the arguments of the function. Tool call: {value:?}"
1181// );
1182
1183// #[cfg(feature = "logging")]
1184// error!(target: "stdout", "{}", &err_msg);
1185
1186// return Err(LlamaCoreError::Operation(err_msg));
1187// }
1188// };
1189
1190// let function = Function { name, arguments };
1191
1192// let tool_call = ToolCall {
1193// id: "call_abc123".to_string(),
1194// ty: "function".to_string(),
1195// function,
1196// };
1197
1198// tool_calls.push(tool_call);
1199// }
1200// Err(e) => {
1201// let err_msg = format!(
1202// "Failed to deserialize generated tool calls. Reason: {e}"
1203// );
1204
1205// #[cfg(feature = "logging")]
1206// error!(target: "stdout", "{}", &err_msg);
1207
1208// return Err(LlamaCoreError::Operation(err_msg));
1209// }
1210// }
1211// } else {
1212// content.push_str(block);
1213// content.push('\n');
1214// }
1215// }
1216// }
1217
1218// let parsed = match content.is_empty() {
1219// true => ParseResult {
1220// raw: input.to_owned(),
1221// content: None,
1222// tool_calls,
1223// },
1224// false => ParseResult {
1225// raw: input.to_owned(),
1226// content: Some(content.trim().to_owned()),
1227// tool_calls,
1228// },
1229// };
1230
1231// #[cfg(feature = "logging")]
1232// info!(target: "stdout", "parsed result: {parsed:?}");
1233
1234// Ok(parsed)
1235// }
1236// PromptTemplateType::NemotronTool => {
1237// #[cfg(feature = "logging")]
1238// info!(target: "stdout", "raw input: {input}");
1239
1240// match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1241// Ok(re) => {
1242// let mut values: Vec<serde_json::Value> = vec![];
1243// for cap in re.captures_iter(input) {
1244// #[cfg(feature = "logging")]
1245// info!(target: "stdout", "captured: {}", &cap[0]);
1246
1247// #[cfg(feature = "logging")]
1248// info!(target: "stdout", "extracted: {}", &cap[1]);
1249
1250// let matched = cap[1].trim();
1251
1252// #[cfg(feature = "logging")]
1253// info!(target: "stdout", "captured: {matched}");
1254
1255// match serde_json::from_str::<serde_json::Value>(matched) {
1256// Ok(value) => values.push(value),
1257// Err(e) => {
1258// let err_msg = format!(
1259// "Failed to deserialize generated tool calls. Reason: {e}"
1260// );
1261
1262// #[cfg(feature = "logging")]
1263// error!(target: "stdout", "{}", &err_msg);
1264
1265// return Err(LlamaCoreError::Operation(err_msg));
1266// }
1267// }
1268// }
1269
1270// let mut tool_calls: Vec<ToolCall> = vec![];
1271// for value in values.iter() {
1272// let name = match value.get("name") {
1273// Some(name) => name.to_string().replace("\"", ""),
1274// None => {
1275// let err_msg = format!(
1276// "Failed to get the name of the function. Tool call: {value:?}"
1277// );
1278
1279// #[cfg(feature = "logging")]
1280// error!(target: "stdout", "{}", &err_msg);
1281
1282// return Err(LlamaCoreError::Operation(err_msg));
1283// }
1284// };
1285
1286// let arguments = match value.get("arguments") {
1287// Some(arguments) => arguments.to_string(),
1288// None => {
1289// let err_msg = format!(
1290// "Failed to get the arguments of the function. Tool call: {value:?}"
1291// );
1292
1293// #[cfg(feature = "logging")]
1294// error!(target: "stdout", "{}", &err_msg);
1295
1296// return Err(LlamaCoreError::Operation(err_msg));
1297// }
1298// };
1299
1300// let function = Function { name, arguments };
1301
1302// let tool_call = ToolCall {
1303// id: "call_abc123".to_string(),
1304// ty: "function".to_string(),
1305// function,
1306// };
1307
1308// tool_calls.push(tool_call);
1309// }
1310
1311// let parsed = ParseResult {
1312// raw: input.to_owned(),
1313// content: None,
1314// tool_calls,
1315// };
1316
1317// #[cfg(feature = "logging")]
1318// info!(target: "stdout", "parsed result: {parsed:?}");
1319
1320// Ok(parsed)
1321// }
1322// Err(e) => {
1323// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1324
1325// #[cfg(feature = "logging")]
1326// error!(target: "stdout", "{}", &err_msg);
1327
1328// Err(LlamaCoreError::Operation(err_msg))
1329// }
1330// }
1331// }
1332// PromptTemplateType::FunctionaryV32 => {
1333// #[cfg(feature = "logging")]
1334// info!(target: "stdout", "raw input: {input}");
1335
1336// match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1337// Ok(re) => {
1338// let mut tool_calls: Vec<ToolCall> = vec![];
1339// for cap in re.captures_iter(input) {
1340// #[cfg(feature = "logging")]
1341// info!(target: "stdout", "func_name: {}", &cap[1]);
1342
1343// #[cfg(feature = "logging")]
1344// info!(target: "stdout", "arguments: {}", &cap[2]);
1345
1346// let tool_call = ToolCall {
1347// id: "call_abc123".to_string(),
1348// ty: "function".to_string(),
1349// function: Function {
1350// name: cap[1].to_string(),
1351// arguments: cap[2].to_string(),
1352// },
1353// };
1354
1355// tool_calls.push(tool_call);
1356// }
1357
1358// let parsed = ParseResult {
1359// raw: input.to_owned(),
1360// content: None,
1361// tool_calls,
1362// };
1363
1364// #[cfg(feature = "logging")]
1365// info!(target: "stdout", "parsed result: {parsed:?}");
1366
1367// Ok(parsed)
1368// }
1369// Err(e) => {
1370// let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1371
1372// #[cfg(feature = "logging")]
1373// warn!(target: "stdout", "{}", &warn_msg);
1374
1375// Ok(ParseResult {
1376// raw: input.to_owned(),
1377// content: None,
1378// tool_calls: vec![],
1379// })
1380// }
1381// }
1382// }
1383// PromptTemplateType::FunctionaryV31 => {
1384// #[cfg(feature = "logging")]
1385// info!(target: "stdout", "raw input: {input}");
1386
1387// match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1388// Ok(re) => {
1389// let mut tool_calls: Vec<ToolCall> = vec![];
1390// for cap in re.captures_iter(input) {
1391// #[cfg(feature = "logging")]
1392// info!(target: "stdout", "func_name: {}", &cap[1]);
1393
1394// #[cfg(feature = "logging")]
1395// info!(target: "stdout", "arguments: {}", &cap[2]);
1396
1397// let tool_call = ToolCall {
1398// id: "call_abc123".to_string(),
1399// ty: "function".to_string(),
1400// function: Function {
1401// name: cap[1].to_string(),
1402// arguments: cap[2].to_string(),
1403// },
1404// };
1405
1406// tool_calls.push(tool_call);
1407// }
1408
1409// let parsed = ParseResult {
1410// raw: input.to_owned(),
1411// content: None,
1412// tool_calls,
1413// };
1414
1415// #[cfg(feature = "logging")]
1416// info!(target: "stdout", "parsed result: {parsed:?}");
1417
1418// Ok(parsed)
1419// }
1420// Err(e) => {
1421// let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1422
1423// #[cfg(feature = "logging")]
1424// warn!(target: "stdout", "{}", &warn_msg);
1425
1426// Ok(ParseResult {
1427// raw: input.to_owned(),
1428// content: None,
1429// tool_calls: vec![],
1430// })
1431// }
1432// }
1433// }
1434// PromptTemplateType::MistralSmallTool => {
1435// #[cfg(feature = "logging")]
1436// info!(target: "stdout", "raw input: {input}");
1437
1438// match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1439// Ok(re) => {
1440// let mut values: Vec<serde_json::Value> = vec![];
1441// if let Some(cap) = re.captures(input) {
1442// let matched = cap[1].trim();
1443
1444// #[cfg(feature = "logging")]
1445// info!(target: "stdout", "captured: {matched}");
1446
1447// match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1448// Ok(vals) => values = vals,
1449// Err(e) => {
1450// let err_msg = format!(
1451// "Failed to deserialize generated tool calls. Reason: {e}"
1452// );
1453
1454// #[cfg(feature = "logging")]
1455// error!(target: "stdout", "{}", &err_msg);
1456
1457// return Err(LlamaCoreError::Operation(err_msg));
1458// }
1459// }
1460// };
1461
1462// let mut tool_calls: Vec<ToolCall> = vec![];
1463// for value in values.iter() {
1464// if let Some(object_map) = value.as_object() {
1465// if object_map.contains_key("function") {
1466// let mut function = Function {
1467// name: String::new(),
1468// arguments: String::new(),
1469// };
1470
1471// let value = object_map.get("function").unwrap();
1472// let func_map = value.as_object().unwrap();
1473// if func_map.contains_key("name") {
1474// let func_name = func_map.get("name").unwrap().as_str().unwrap();
1475// println!("Function name: {func_name:?}");
1476
1477// function.name = func_name.to_string();
1478// }
1479// if func_map.contains_key("arguments") {
1480// let args = func_map.get("arguments").unwrap();
1481// let arguments = args.to_string();
1482// println!("Arguments: {arguments:?}");
1483
1484// function.arguments = arguments;
1485// }
1486
1487// let tool_call = ToolCall {
1488// id: "call_abc123".to_string(),
1489// ty: "function".to_string(),
1490// function,
1491// };
1492
1493// tool_calls.push(tool_call);
1494// } else if object_map.contains_key("name") {
1495// let mut function = Function {
1496// name: String::new(),
1497// arguments: String::new(),
1498// };
1499
1500// let name = object_map.get("name").unwrap().as_str().unwrap();
1501// println!("name: {name:?}");
1502// function.name = name.to_string();
1503
1504// if object_map.contains_key("arguments") {
1505// let args = object_map.get("arguments").unwrap();
1506// let arguments = args.to_string();
1507// println!("Arguments: {arguments:?}");
1508
1509// function.arguments = arguments;
1510// }
1511
1512// let tool_call = ToolCall {
1513// id: "call_abc123".to_string(),
1514// ty: "function".to_string(),
1515// function,
1516// };
1517
1518// tool_calls.push(tool_call);
1519// }
1520// }
1521// }
1522
1523// let parsed = ParseResult {
1524// raw: input.to_owned(),
1525// content: None,
1526// tool_calls,
1527// };
1528
1529// #[cfg(feature = "logging")]
1530// info!(target: "stdout", "parsed result: {parsed:?}");
1531
1532// Ok(parsed)
1533// }
1534// Err(e) => {
1535// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1536
1537// #[cfg(feature = "logging")]
1538// error!(target: "stdout", "{}", &err_msg);
1539
1540// Err(LlamaCoreError::Operation(err_msg))
1541// }
1542// }
1543// }
1544// PromptTemplateType::Llama4Chat => {
1545// #[cfg(feature = "logging")]
1546// info!(target: "stdout", "raw input: {input:?}");
1547
1548// let mut tool_calls: Vec<ToolCall> = vec![];
1549// if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
1550// match value.as_object() {
1551// Some(object_map) => {
1552// #[cfg(feature = "logging")]
1553// debug!(target: "stdout", "object_map: {object_map:?}");
1554
1555// // parse function name
1556// if object_map.contains_key("name") {
1557// let name = object_map.get("name").unwrap().as_str().unwrap();
1558
1559// #[cfg(feature = "logging")]
1560// debug!(target: "stdout", "name: {name:?}");
1561
1562// let mut function = Function {
1563// name: name.to_string(),
1564// arguments: String::new(),
1565// };
1566
1567// // parse function arguments
1568// if object_map.contains_key("parameters") {
1569// let args = object_map.get("parameters").unwrap();
1570// let arguments = args.to_string();
1571
1572// #[cfg(feature = "logging")]
1573// debug!(target: "stdout", "arguments: {:?}", &arguments);
1574
1575// function.arguments = arguments;
1576// }
1577
1578// tool_calls.push(ToolCall {
1579// id: "call_abc123".to_string(),
1580// ty: "function".to_string(),
1581// function,
1582// });
1583// } else {
1584// let err_msg = format!(
1585// "Failed to get the name of the function. raw input: {input:?}"
1586// );
1587
1588// #[cfg(feature = "logging")]
1589// error!(target: "stdout", "{}", &err_msg);
1590
1591// return Err(LlamaCoreError::Operation(err_msg));
1592// }
1593// }
1594// None => {
1595// let err_msg = format!("Failed to parse the JSON string. JSON: {input}");
1596
1597// #[cfg(feature = "logging")]
1598// error!(target: "stdout", "{}", &err_msg);
1599
1600// return Err(LlamaCoreError::Operation(err_msg));
1601// }
1602// }
1603// }
1604
1605// let parsed = ParseResult {
1606// raw: input.to_owned(),
1607// content: None,
1608// tool_calls,
1609// };
1610
1611// #[cfg(feature = "logging")]
1612// info!(target: "stdout", "parsed result: {parsed:?}");
1613
1614// Ok(parsed)
1615// }
1616// PromptTemplateType::Qwen3NoThink | PromptTemplateType::Smol3NoThink => {
1617// #[cfg(feature = "logging")]
1618// info!(target: "stdout", "raw input: {input:?}");
1619
1620// match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1621// Ok(re) => {
1622// let mut values: Vec<serde_json::Value> = vec![];
1623// for cap in re.captures_iter(input) {
1624// let mut matched = cap[1].trim();
1625
1626// if matched.starts_with("\\n") {
1627// matched = matched.trim_start_matches("\\n");
1628// }
1629
1630// if matched.ends_with("\\n") {
1631// matched = matched.trim_end_matches("\\n");
1632// }
1633
1634// #[cfg(feature = "logging")]
1635// info!(target: "stdout", "captured: {matched:#?}");
1636
1637// if !matched.is_empty() {
1638// match serde_json::from_str::<serde_json::Value>(matched) {
1639// Ok(value) => values.push(value),
1640// Err(e) => {
1641// let err_msg = format!(
1642// "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
1643// );
1644
1645// #[cfg(feature = "logging")]
1646// error!(target: "stdout", "{}", &err_msg);
1647
1648// return Err(LlamaCoreError::Operation(err_msg));
1649// }
1650// }
1651// }
1652// }
1653
1654// let mut tool_calls: Vec<ToolCall> = vec![];
1655// for value in values.iter() {
1656// let name = match value.get("name") {
1657// Some(name) => name.to_string().replace("\"", ""),
1658// None => {
1659// let err_msg = format!(
1660// "Failed to get the name of the function. Tool call: {value:?}"
1661// );
1662
1663// #[cfg(feature = "logging")]
1664// error!(target: "stdout", "{}", &err_msg);
1665
1666// return Err(LlamaCoreError::Operation(err_msg));
1667// }
1668// };
1669
1670// let arguments = match value.get("arguments") {
1671// Some(arguments) => {
1672// if arguments.is_string() {
1673// arguments.as_str().unwrap().to_string()
1674// } else if arguments.is_object() {
1675// let map = arguments.as_object().unwrap();
1676
1677// #[cfg(feature = "logging")]
1678// info!(target: "stdout", "func arguments: {map:?}");
1679
1680// serde_json::to_string(map).unwrap()
1681// } else {
1682// serde_json::to_string(arguments).unwrap()
1683// }
1684// }
1685// None => {
1686// let err_msg = format!(
1687// "Failed to get the arguments of the function. Tool call: {value:?}"
1688// );
1689
1690// #[cfg(feature = "logging")]
1691// error!(target: "stdout", "{}", &err_msg);
1692
1693// return Err(LlamaCoreError::Operation(err_msg));
1694// }
1695// };
1696
1697// let function = Function { name, arguments };
1698
1699// let tool_call = ToolCall {
1700// id: "call_abc123".to_string(),
1701// ty: "function".to_string(),
1702// function,
1703// };
1704
1705// tool_calls.push(tool_call);
1706// }
1707
1708// let parsed = if tool_calls.is_empty() {
1709// ParseResult {
1710// raw: input.to_owned(),
1711// content: Some(input.to_owned()),
1712// tool_calls: vec![],
1713// }
1714// } else {
1715// ParseResult {
1716// raw: input.to_owned(),
1717// content: None,
1718// tool_calls,
1719// }
1720// };
1721
1722// #[cfg(feature = "logging")]
1723// info!(target: "stdout", "parsed result: {parsed:?}");
1724
1725// Ok(parsed)
1726// }
1727// Err(e) => {
1728// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1729
1730// #[cfg(feature = "logging")]
1731// error!(target: "stdout", "{}", &err_msg);
1732
1733// Err(LlamaCoreError::Operation(err_msg))
1734// }
1735// }
1736// }
1737// PromptTemplateType::Gemma3 => {
1738// #[cfg(feature = "logging")]
1739// info!(target: "stdout", "raw input: {input:?}");
1740
1741// match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
1742// Ok(re) => {
1743// let mut values: Vec<serde_json::Value> = vec![];
1744// for cap in re.captures_iter(input) {
1745// let mut matched = cap[1].trim();
1746
1747// if matched.starts_with("\\n") {
1748// matched = matched.trim_start_matches("\\n");
1749// }
1750
1751// if matched.ends_with("\\n") {
1752// matched = matched.trim_end_matches("\\n");
1753// }
1754
1755// #[cfg(feature = "logging")]
1756// info!(target: "stdout", "captured: {matched:#?}");
1757
1758// if !matched.is_empty() {
1759// match serde_json::from_str::<serde_json::Value>(matched) {
1760// Ok(value) => values.push(value),
1761// Err(e) => {
1762// let err_msg = format!(
1763// "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
1764// );
1765
1766// #[cfg(feature = "logging")]
1767// error!(target: "stdout", "{}", &err_msg);
1768
1769// return Err(LlamaCoreError::Operation(err_msg));
1770// }
1771// }
1772// }
1773// }
1774
1775// let mut tool_calls: Vec<ToolCall> = vec![];
1776// for value in values.iter() {
1777// let name = match value.get("name") {
1778// Some(name) => name.to_string().replace("\"", ""),
1779// None => {
1780// let err_msg = format!(
1781// "Failed to get the name of the function. Tool call: {value:?}"
1782// );
1783
1784// #[cfg(feature = "logging")]
1785// error!(target: "stdout", "{}", &err_msg);
1786
1787// return Err(LlamaCoreError::Operation(err_msg));
1788// }
1789// };
1790
1791// let arguments = match value.get("arguments") {
1792// Some(arguments) => {
1793// if arguments.is_string() {
1794// arguments.as_str().unwrap().to_string()
1795// } else if arguments.is_object() {
1796// let map = arguments.as_object().unwrap();
1797
1798// #[cfg(feature = "logging")]
1799// info!(target: "stdout", "func arguments: {map:?}");
1800
1801// serde_json::to_string(map).unwrap()
1802// } else {
1803// serde_json::to_string(arguments).unwrap()
1804// }
1805// }
1806// None => {
1807// let err_msg = format!(
1808// "Failed to get the arguments of the function. Tool call: {value:?}"
1809// );
1810
1811// #[cfg(feature = "logging")]
1812// error!(target: "stdout", "{}", &err_msg);
1813
1814// return Err(LlamaCoreError::Operation(err_msg));
1815// }
1816// };
1817
1818// let function = Function { name, arguments };
1819
1820// let tool_call = ToolCall {
1821// id: "call_abc123".to_string(),
1822// ty: "function".to_string(),
1823// function,
1824// };
1825
1826// tool_calls.push(tool_call);
1827// }
1828
1829// let parsed = if tool_calls.is_empty() {
1830// ParseResult {
1831// raw: input.to_owned(),
1832// content: Some(input.to_owned()),
1833// tool_calls: vec![],
1834// }
1835// } else {
1836// ParseResult {
1837// raw: input.to_owned(),
1838// content: None,
1839// tool_calls,
1840// }
1841// };
1842
1843// #[cfg(feature = "logging")]
1844// info!(target: "stdout", "parsed result: {parsed:?}");
1845
1846// Ok(parsed)
1847// }
1848// Err(e) => {
1849// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1850
1851// #[cfg(feature = "logging")]
1852// error!(target: "stdout", "{}", &err_msg);
1853
1854// Err(LlamaCoreError::Operation(err_msg))
1855// }
1856// }
1857// }
1858// PromptTemplateType::GptOss => {
1859// #[cfg(feature = "logging")]
1860// info!(target: "stdout", "raw input: {input:?}");
1861
1862// // Match strings ending with: <|channel|>commentary to=functions.xxxxx <|constrain|>json<|message|>yyyyy<|call|>
1863// match regex::Regex::new(
1864// r"<\|channel\|>commentary to=functions\.([^<\s]+)\s*<\|constrain\|>json<\|message\|>([^<]*)<\|call\|>$",
1865// ) {
1866// Ok(re) => {
1867// if let Some(cap) = re.captures(input) {
1868// let function_name = cap[1].trim();
1869// let arguments = cap[2].trim();
1870
1871// #[cfg(feature = "logging")]
1872// info!(target: "stdout", "extracted function_name: {function_name}, arguments: {arguments}");
1873
1874// let function = Function {
1875// name: function_name.to_string(),
1876// arguments: arguments.to_string(),
1877// };
1878
1879// let tool_call = ToolCall {
1880// id: "call_abc123".to_string(),
1881// ty: "function".to_string(),
1882// function,
1883// };
1884
1885// let parsed = ParseResult {
1886// raw: input.to_owned(),
1887// content: None,
1888// tool_calls: vec![tool_call],
1889// };
1890
1891// #[cfg(feature = "logging")]
1892// info!(target: "stdout", "parsed result: {parsed:?}");
1893
1894// Ok(parsed)
1895// } else {
1896// match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
1897// Ok(re) => {
1898// let mut values: Vec<serde_json::Value> = vec![];
1899// for cap in re.captures_iter(input) {
1900// let mut matched = cap[1].trim();
1901
1902// if matched.starts_with("\\n") {
1903// matched = matched.trim_start_matches("\\n");
1904// }
1905
1906// if matched.ends_with("\\n") {
1907// matched = matched.trim_end_matches("\\n");
1908// }
1909
1910// #[cfg(feature = "logging")]
1911// info!(target: "stdout", "captured: {matched:#?}");
1912
1913// if !matched.is_empty() {
1914// match serde_json::from_str::<serde_json::Value>(matched) {
1915// Ok(value) => values.push(value),
1916// Err(e) => {
1917// let err_msg = format!(
1918// "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
1919// );
1920
1921// #[cfg(feature = "logging")]
1922// error!(target: "stdout", "{}", &err_msg);
1923
1924// return Err(LlamaCoreError::Operation(err_msg));
1925// }
1926// }
1927// }
1928// }
1929
1930// let mut tool_calls: Vec<ToolCall> = vec![];
1931// for value in values.iter() {
1932// let name = match value.get("name") {
1933// Some(name) => name.to_string().replace("\"", ""),
1934// None => {
1935// let err_msg = format!(
1936// "Failed to get the name of the function. Tool call: {value:?}"
1937// );
1938
1939// #[cfg(feature = "logging")]
1940// error!(target: "stdout", "{}", &err_msg);
1941
1942// return Err(LlamaCoreError::Operation(err_msg));
1943// }
1944// };
1945
1946// let arguments = match value.get("arguments") {
1947// Some(arguments) => {
1948// if arguments.is_string() {
1949// arguments.as_str().unwrap().to_string()
1950// } else if arguments.is_object() {
1951// let map = arguments.as_object().unwrap();
1952
1953// #[cfg(feature = "logging")]
1954// info!(target: "stdout", "func arguments: {map:?}");
1955
1956// serde_json::to_string(map).unwrap()
1957// } else {
1958// serde_json::to_string(arguments).unwrap()
1959// }
1960// }
1961// None => {
1962// let err_msg = format!(
1963// "Failed to get the arguments of the function. Tool call: {value:?}"
1964// );
1965
1966// #[cfg(feature = "logging")]
1967// error!(target: "stdout", "{}", &err_msg);
1968
1969// return Err(LlamaCoreError::Operation(err_msg));
1970// }
1971// };
1972
1973// let function = Function { name, arguments };
1974
1975// let tool_call = ToolCall {
1976// id: "call_abc123".to_string(),
1977// ty: "function".to_string(),
1978// function,
1979// };
1980
1981// tool_calls.push(tool_call);
1982// }
1983
1984// let parsed = if tool_calls.is_empty() {
1985// ParseResult {
1986// raw: input.to_owned(),
1987// content: Some(input.to_owned()),
1988// tool_calls: vec![],
1989// }
1990// } else {
1991// ParseResult {
1992// raw: input.to_owned(),
1993// content: Some(input.to_owned()),
1994// tool_calls,
1995// }
1996// };
1997
1998// #[cfg(feature = "logging")]
1999// info!(target: "stdout", "parsed result: {parsed:?}");
2000
2001// Ok(parsed)
2002// }
2003// Err(e) => {
2004// let err_msg =
2005// format!("Failed to create a regex pattern. Reason: {e}");
2006
2007// #[cfg(feature = "logging")]
2008// error!(target: "stdout", "{}", &err_msg);
2009
2010// Err(LlamaCoreError::Operation(err_msg))
2011// }
2012// }
2013// }
2014// }
2015// Err(e) => {
2016// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2017
2018// #[cfg(feature = "logging")]
2019// error!(target: "stdout", "{}", &err_msg);
2020
2021// Err(LlamaCoreError::Operation(err_msg))
2022// }
2023// }
2024// }
2025// PromptTemplateType::Qwen3Agent => {
2026// #[cfg(feature = "logging")]
2027// info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2028
2029// // detect <action> tags
2030// match regex::Regex::new(r"<action>(.*?)</action>")
2031// .unwrap()
2032// .captures(input)
2033// {
2034// Some(captures) => {
2035// let action = captures.get(1).unwrap().as_str();
2036
2037// #[cfg(feature = "logging")]
2038// info!(target: "stdout", "Action: {action}");
2039
2040// match serde_json::from_str::<serde_json::Value>(action) {
2041// Ok(value) => {
2042// let name = match value.get("name") {
2043// Some(name) => name.to_string().replace("\"", ""),
2044// None => {
2045// let err_msg = format!(
2046// "Failed to get the name of the function. Tool call: {value:?}"
2047// );
2048
2049// #[cfg(feature = "logging")]
2050// error!(target: "stdout", "{}", &err_msg);
2051
2052// return Err(LlamaCoreError::Operation(err_msg));
2053// }
2054// };
2055
2056// let arguments = match value.get("arguments") {
2057// Some(arguments) => {
2058// if arguments.is_string() {
2059// arguments.as_str().unwrap().to_string()
2060// } else if arguments.is_object() {
2061// let map = arguments.as_object().unwrap();
2062
2063// #[cfg(feature = "logging")]
2064// info!(target: "stdout", "func arguments: {map:?}");
2065
2066// serde_json::to_string(map).unwrap()
2067// } else {
2068// serde_json::to_string(arguments).unwrap()
2069// }
2070// }
2071// None => {
2072// let err_msg = format!(
2073// "Failed to get the arguments of the function. Tool call: {value:?}"
2074// );
2075
2076// #[cfg(feature = "logging")]
2077// error!(target: "stdout", "{}", &err_msg);
2078
2079// return Err(LlamaCoreError::Operation(err_msg));
2080// }
2081// };
2082
2083// let function = Function { name, arguments };
2084
2085// let tool_call = ToolCall {
2086// id: "call_abc123".to_string(),
2087// ty: "function".to_string(),
2088// function,
2089// };
2090
2091// Ok(ParseResult {
2092// raw: input.to_owned(),
2093// content: Some(input.to_owned()),
2094// tool_calls: vec![tool_call],
2095// })
2096// }
2097// Err(e) => {
2098// let err_msg = format!(
2099// "Failed to deserialize generated tool calls: {action:#?}. Reason: {e}"
2100// );
2101
2102// #[cfg(feature = "logging")]
2103// error!(target: "stdout", "{}", &err_msg);
2104
2105// Err(LlamaCoreError::Operation(err_msg))
2106// }
2107// }
2108// }
2109// None => match input.contains("<final_answer>") {
2110// true => Ok(ParseResult {
2111// raw: input.to_owned(),
2112// content: Some(input.to_owned()),
2113// tool_calls: vec![],
2114// }),
2115// false => {
2116// let content = format!("<final_answer>{}</final_answer>", input.trim());
2117
2118// Ok(ParseResult {
2119// raw: input.to_owned(),
2120// content: Some(content),
2121// tool_calls: vec![],
2122// })
2123// }
2124// },
2125// }
2126// }
2127// PromptTemplateType::SeedOssNoThink | PromptTemplateType::SeedOssThink => {
2128// #[cfg(feature = "logging")]
2129// info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2130
2131// match regex::Regex::new(r"```json\n([\s\S]*?)\n") {
2132// Ok(re) => {
2133// let mut values: Vec<serde_json::Value> = vec![];
2134// for cap in re.captures_iter(input) {
2135// let mut matched = cap[1].trim();
2136
2137// if matched.starts_with("\\n") {
2138// matched = matched.trim_start_matches("\\n");
2139// }
2140
2141// if matched.ends_with("\\n") {
2142// matched = matched.trim_end_matches("\\n");
2143// }
2144
2145// #[cfg(feature = "logging")]
2146// info!(target: "stdout", "captured: {matched:#?}");
2147
2148// if !matched.is_empty() {
2149// match serde_json::from_str::<serde_json::Value>(matched) {
2150// Ok(value) => values.push(value),
2151// Err(e) => {
2152// let err_msg = format!(
2153// "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2154// );
2155
2156// #[cfg(feature = "logging")]
2157// error!(target: "stdout", "{}", &err_msg);
2158
2159// return Err(LlamaCoreError::Operation(err_msg));
2160// }
2161// }
2162// }
2163// }
2164
2165// let mut tool_calls: Vec<ToolCall> = vec![];
2166// for value in values.iter() {
2167// let name = match value.get("name") {
2168// Some(name) => name.to_string().replace("\"", ""),
2169// None => {
2170// let err_msg = format!(
2171// "Failed to get the name of the function. Tool call: {value:?}"
2172// );
2173
2174// #[cfg(feature = "logging")]
2175// error!(target: "stdout", "{}", &err_msg);
2176
2177// return Err(LlamaCoreError::Operation(err_msg));
2178// }
2179// };
2180
2181// let arguments = match value.get("arguments") {
2182// Some(arguments) => {
2183// if arguments.is_string() {
2184// arguments.as_str().unwrap().to_string()
2185// } else if arguments.is_object() {
2186// let map = arguments.as_object().unwrap();
2187
2188// #[cfg(feature = "logging")]
2189// info!(target: "stdout", "func arguments: {map:?}");
2190
2191// serde_json::to_string(map).unwrap()
2192// } else {
2193// serde_json::to_string(arguments).unwrap()
2194// }
2195// }
2196// None => {
2197// let err_msg = format!(
2198// "Failed to get the arguments of the function. Tool call: {value:?}"
2199// );
2200
2201// #[cfg(feature = "logging")]
2202// error!(target: "stdout", "{}", &err_msg);
2203
2204// return Err(LlamaCoreError::Operation(err_msg));
2205// }
2206// };
2207
2208// let function = Function { name, arguments };
2209
2210// let tool_call = ToolCall {
2211// id: "call_abc123".to_string(),
2212// ty: "function".to_string(),
2213// function,
2214// };
2215
2216// tool_calls.push(tool_call);
2217// }
2218
2219// let parsed = if tool_calls.is_empty() {
2220// ParseResult {
2221// raw: input.to_owned(),
2222// content: Some(input.to_owned()),
2223// tool_calls: vec![],
2224// }
2225// } else {
2226// ParseResult {
2227// raw: input.to_owned(),
2228// content: None,
2229// tool_calls,
2230// }
2231// };
2232
2233// #[cfg(feature = "logging")]
2234// info!(target: "stdout", "parsed result: {parsed:?}");
2235
2236// Ok(parsed)
2237// }
2238// Err(e) => {
2239// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2240
2241// #[cfg(feature = "logging")]
2242// error!(target: "stdout", "{}", &err_msg);
2243
2244// Err(LlamaCoreError::Operation(err_msg))
2245// }
2246// }
2247// }
2248// _ => {
2249// let err_msg = format!(
2250// "The tool use is only supported for prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
2251// PromptTemplateType::MistralTool,
2252// PromptTemplateType::ChatMLTool,
2253// PromptTemplateType::GroqLlama3Tool,
2254// PromptTemplateType::Llama3Tool,
2255// PromptTemplateType::InternLM2Tool,
2256// PromptTemplateType::NemotronTool,
2257// PromptTemplateType::FunctionaryV32,
2258// PromptTemplateType::MistralSmallTool,
2259// PromptTemplateType::Llama4Chat,
2260// PromptTemplateType::Qwen3NoThink,
2261// PromptTemplateType::Smol3NoThink,
2262// PromptTemplateType::Gemma3,
2263// PromptTemplateType::GptOss,
2264// PromptTemplateType::Qwen3Agent,
2265// PromptTemplateType::SeedOssNoThink,
2266// PromptTemplateType::SeedOssThink
2267// );
2268
2269// #[cfg(feature = "logging")]
2270// error!(target: "stdout", "{}", &err_msg);
2271
2272// Err(LlamaCoreError::Operation(err_msg))
2273// }
2274// }
2275// }
2276
2277fn check_model_metadata(
2278 chat_request: &RequestOfModelResponse,
2279) -> Result<GgmlMetadata, LlamaCoreError> {
2280 let mut should_update = false;
2281
2282 let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
2283
2284 // TODO: check if necessary to update `image`
2285 // if metadata.prompt_template.is_image_supported() {
2286 // if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
2287 // {
2288 // if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
2289 // for part in parts {
2290 // if let ContentPart::Image(image_part) = part {
2291 // let image = image_part.image();
2292
2293 // if image.is_url() {
2294 // let err_msg = "The image is provided in URL format. Only base64 format is supported.".to_string();
2295
2296 // #[cfg(feature = "logging")]
2297 // error!(target: "stdout", "{}", &err_msg);
2298
2299 // return Err(LlamaCoreError::Operation(err_msg));
2300 // } else {
2301 // #[cfg(feature = "logging")]
2302 // info!(target: "stdout", "The image is provided in base64 format.");
2303
2304 // // TODO: now only support a single image
2305
2306 // break;
2307 // }
2308 // }
2309 // }
2310 // }
2311 // }
2312 // }
2313
2314 // check if necessary to update temperature
2315 if let Some(temp) = chat_request.temperature {
2316 if metadata.temperature != temp {
2317 // update temperature
2318 metadata.temperature = temp;
2319
2320 if !should_update {
2321 should_update = true;
2322 }
2323 }
2324 }
2325
2326 // check if necessary to update top_p
2327 if let Some(top_p) = chat_request.top_p {
2328 if metadata.top_p != top_p {
2329 // update top_p
2330 metadata.top_p = top_p;
2331
2332 if !should_update {
2333 should_update = true;
2334 }
2335 }
2336 }
2337
2338 // check if the `embedding` option is disabled
2339 if metadata.embeddings {
2340 metadata.embeddings = false;
2341
2342 if !should_update {
2343 should_update = true;
2344 }
2345 }
2346
2347 if should_update {
2348 #[cfg(feature = "logging")]
2349 info!(target: "stdout", "Update the model metadata.");
2350
2351 // update the target graph with the new metadata
2352 update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2353 }
2354
2355 Ok(metadata)
2356}
2357
2358fn update_n_predict(
2359 chat_request: &RequestOfModelResponse,
2360 metadata: &mut GgmlMetadata,
2361 available_completion_tokens: u64,
2362) -> Result<(), LlamaCoreError> {
2363 let mut should_update = false;
2364
2365 #[cfg(feature = "logging")]
2366 info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2367
2368 // From high to low priority
2369 // 1. chat_request.max_completion_tokens
2370 // 2. available_completion_tokens
2371 // 3. n_predict
2372
2373 if let Some(max_output_tokens) = chat_request.max_output_tokens {
2374 if metadata.n_predict != max_output_tokens {
2375 #[cfg(feature = "logging")]
2376 info!(target: "stdout", "Update n_predict with max_output_tokens from {} to {}", metadata.n_predict, max_output_tokens);
2377
2378 metadata.n_predict = max_output_tokens;
2379
2380 if !should_update {
2381 should_update = true;
2382 }
2383 }
2384 }
2385
2386 // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2387 if metadata.n_predict == -2 {
2388 #[cfg(feature = "logging")]
2389 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2390
2391 // update n_predict
2392 metadata.n_predict = available_completion_tokens as i32;
2393
2394 if !should_update {
2395 should_update = true;
2396 }
2397 }
2398
2399 if metadata.n_predict == -1
2400 || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2401 || (metadata.n_predict < 0 && metadata.n_predict != -2)
2402 // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2403 {
2404 #[cfg(feature = "logging")]
2405 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2406
2407 // update n_predict
2408 metadata.n_predict = available_completion_tokens as i32;
2409
2410 if !should_update {
2411 should_update = true;
2412 }
2413 }
2414
2415 if should_update {
2416 #[cfg(feature = "logging")]
2417 info!(target: "stdout", "Update the model metadata.");
2418
2419 // update the target graph with the new metadata
2420 update_model_metadata(chat_request.model.as_ref(), metadata)?;
2421 }
2422
2423 Ok(())
2424}
2425
/// Post-process the raw model output according to the prompt template type.
///
/// Strips template-specific stop/control tokens (e.g. `<|im_end|>`, `</s>`,
/// `<end_of_turn>`), leading artifacts (stray `:` prefixes, empty `<think>`
/// blocks), and surrounding whitespace so the returned string contains only
/// the assistant's message text.
///
/// # Returns
///
/// The cleaned output. The `Err` variant is declared but never produced by
/// the current implementation.
fn post_process(
    output: impl AsRef<str>,
    template_ty: &PromptTemplateType,
) -> Result<String, String> {
    let output = if *template_ty == PromptTemplateType::Baichuan2 {
        // Baichuan2: drop a trailing next-user-turn marker
        if output.as_ref().contains("用户:") {
            output.as_ref().trim_end_matches("用户:").trim().to_owned()
        } else {
            output.as_ref().trim().to_owned()
        }
    } else if *template_ty == PromptTemplateType::OpenChat {
        // OpenChat: drop the end-of-turn token
        if output.as_ref().contains("<|end_of_turn|>") {
            output
                .as_ref()
                .trim_end_matches("<|end_of_turn|>")
                .trim()
                .to_owned()
        } else {
            output.as_ref().trim().to_owned()
        }
    } else if *template_ty == PromptTemplateType::GemmaInstruct
        || *template_ty == PromptTemplateType::Gemma3
    {
        // Gemma: drop a trailing <end_of_turn>
        let s = output.as_ref().trim();
        if s.ends_with("<end_of_turn>") {
            s.trim_end_matches("<end_of_turn>").trim().to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::ChatML
        || *template_ty == PromptTemplateType::ChatMLTool
        || *template_ty == PromptTemplateType::InternLM2Tool
        || *template_ty == PromptTemplateType::MiniCPMV
    {
        // ChatML family: strip <|endoftext|>, a stray leading ":", empty
        // Qwen3 think tags, then anything around <|im_start|>/<|im_end|>.
        let mut s = output.as_ref().trim();
        if s.ends_with("<|endoftext|>") {
            s = s.trim_end_matches("<|endoftext|>").trim();
        }

        if s.starts_with(":") {
            s = s.trim_start_matches(":").trim();
        }

        // handle Qwen3 empty think tags
        let x = {
            let pat = r#"<think>

</think>
"#;
            if s.contains(pat) {
                let x = s.replace(pat, "");
                if x.starts_with("()") {
                    x.trim_start_matches("()").to_owned()
                } else {
                    x.to_owned()
                }
            } else {
                s.to_owned()
            }
        };
        s = x.trim();

        if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
            let idx_start = s.find("<|im_start|>").unwrap();
            let idx_end = s.find("<|im_end|>").unwrap();

            // keep only the text before whichever marker appears first
            match idx_start <= idx_end {
                true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
                    .trim()
                    .to_owned(),
                false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
                    .trim()
                    .to_owned(),
            }
        } else if s.contains("<|im_start|>") {
            s.split("<|im_start|>").collect::<Vec<_>>()[0]
                .trim()
                .to_owned()
        } else if s.contains("<|im_end|>") {
            let output = s.trim_end_matches("<|im_end|>").trim();
            if output.starts_with(": ") {
                output.trim_start_matches(": ").to_owned()
            } else {
                output.to_owned()
            }
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::Zephyr
        || *template_ty == PromptTemplateType::MistralLite
        || *template_ty == PromptTemplateType::MistralTool
        || *template_ty == PromptTemplateType::MistralInstruct
        || *template_ty == PromptTemplateType::MistralSmallChat
        || *template_ty == PromptTemplateType::MistralSmallTool
        || *template_ty == PromptTemplateType::BreezeInstruct
    {
        // Mistral/Zephyr family: drop a trailing "</s>" (or "</s><")
        if output.as_ref().contains("</s><") {
            output.as_ref().trim_end_matches("</s><").trim().to_owned()
        } else if output.as_ref().contains("</s>") {
            // NOTE(review): `strip_suffix(..).unwrap()` panics if "</s>"
            // occurs mid-string but not at the end — confirm outputs always
            // terminate with it, or guard with `contains` -> `ends_with`.
            output
                .as_ref()
                .strip_suffix("</s>")
                .unwrap()
                .trim()
                .to_owned()
        } else {
            output.as_ref().trim().to_owned()
        }
    } else if *template_ty == PromptTemplateType::DeepseekChat {
        // Deepseek: remove every <|end_of_sentence|> occurrence
        if output.as_ref().contains("<|end_of_sentence|>") {
            output
                .as_ref()
                .trim_end_matches("<|end_of_sentence|>")
                .trim()
                .replace("<|end_of_sentence|>", " ")
                .trim()
                .to_owned()
        } else {
            output.as_ref().trim().to_owned()
        }
    } else if *template_ty == PromptTemplateType::HumanAssistant {
        // Human/Assistant: drop a trailing "Human:" turn marker
        if output.as_ref().contains("Human:") {
            output.as_ref().trim_end_matches("Human:").trim().to_owned()
        } else {
            output.as_ref().trim().to_owned()
        }
    } else if *template_ty == PromptTemplateType::SolarInstruct {
        // Solar: normalize a leading "### Answer" header
        let s = output.as_ref().trim();

        if s.starts_with("### Answer") {
            let s = s.trim_start_matches("###").trim();

            if s.starts_with("Answer:\n") {
                s.replace("Answer:\n", "Answer: ")
            } else {
                s.to_owned()
            }
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::Llama2Chat
        || *template_ty == PromptTemplateType::NemotronTool
        || *template_ty == PromptTemplateType::NemotronChat
    {
        // Llama2/Nemotron: drop a trailing "</s>"
        let s = output.as_ref().trim();
        if s.ends_with("</s>") {
            s.trim_end_matches("</s>").trim().to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::Llama3Chat
        || *template_ty == PromptTemplateType::GroqLlama3Tool
        || *template_ty == PromptTemplateType::Llama3Tool
        || *template_ty == PromptTemplateType::FunctionaryV32
    {
        // Llama3 family: drop a trailing <|eot_id|>
        let s = output.as_ref().trim();
        if s.ends_with("<|eot_id|>") {
            s.trim_end_matches("<|eot_id|>").trim().to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::Phi3Chat {
        let s = output.as_ref().trim();
        if s.ends_with("<|end|>") {
            s.trim_end_matches("<|end|>").trim().to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::Phi4Chat {
        // Phi4: drop a leading "think>" fragment and a trailing end token
        let mut s = output.as_ref().trim();

        if s.starts_with("think>") {
            s = s.trim_start_matches("think>").trim();
        }

        if s.ends_with("<|im_end|>") {
            s.trim_end_matches("<|im_end|>").trim().to_owned()
        } else if s.ends_with("<|end|>") {
            s.trim_end_matches("<|end|>").trim().to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::FunctionaryV31 {
        // Functionary v3.1: drop trailing <|eot_id|> and/or <|eom_id|>
        let mut s = output.as_ref().trim();
        if s.ends_with("<|eot_id|>") {
            s = s.trim_end_matches("<|eot_id|>").trim();
        }
        if s.ends_with("<|eom_id|>") {
            s = s.trim_end_matches("<|eom_id|>").trim();
        }
        s.to_owned()
    } else if *template_ty == PromptTemplateType::MoxinChat
        || *template_ty == PromptTemplateType::MoxinInstruct
    {
        let s = output.as_ref().trim();
        if s.ends_with("</s>") {
            s.trim_end_matches("</s>").trim().to_owned()
        } else if s.ends_with("[INST]") {
            s.trim_end_matches("[INST]").trim().to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::Falcon3 {
        let s = output.as_ref().trim();
        if s.ends_with("<|endoftext|>") {
            s.trim_end_matches("<|endoftext|>").trim().to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::Megrez {
        let s = output.as_ref().trim();
        if s.ends_with("<|turn_end|>") {
            s.trim_end_matches("<|turn_end|>").trim().to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::Qwen2vl
        || *template_ty == PromptTemplateType::Qwen3NoThink
        || *template_ty == PromptTemplateType::ChatMLThink
    {
        // Qwen family: strip leading ":" / "</think>" and trailing <|im_end|>
        let mut s = output.as_ref().trim();

        if s.starts_with(":") {
            s = s.trim_start_matches(":").trim();
        }

        if s.starts_with("</think>") {
            s = s.trim_start_matches("</think>").trim();
        }

        if s.ends_with("<|im_end|>") {
            s.trim_end_matches("<|im_end|>").trim().to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::VicunaLlava {
        let s = output.as_ref().trim();
        if s.ends_with("</s>") {
            s.trim_end_matches("</s>").trim().to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::ExaoneDeepChat
        || *template_ty == PromptTemplateType::ExaoneChat
    {
        let mut s = output.as_ref().trim();

        if s.ends_with("[|endofturn|]") {
            s = s.trim_end_matches("[|endofturn|]").trim();
        }

        s.to_owned()
    } else if *template_ty == PromptTemplateType::Llama4Chat {
        let mut s = output.as_ref().trim();

        if s.ends_with("<|eot|>") {
            s = s.trim_end_matches("<|eot|>").trim();
        }

        s.to_owned()
    } else if *template_ty == PromptTemplateType::Smolvl {
        // SmolVL: strip leading ":", trailing <end_of_utterance>, and keep
        // only the text after the last "<end_of_utterance>:" separator
        let mut s = output.as_ref().trim();

        if s.starts_with(":") {
            s = s.trim_start_matches(":").trim();
        }

        if s.ends_with("<end_of_utterance>") {
            s = s.trim_end_matches("<end_of_utterance>").trim();
        }

        if s.contains("<end_of_utterance>:") {
            let parts = s.split("<end_of_utterance>:").collect::<Vec<_>>();
            parts.last().unwrap().trim().to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::Smol3NoThink {
        // Smol3: drop trailing <|im_end|> and a leading <think>…</think> block
        let mut s = output.as_ref().trim();

        if s.ends_with("<|im_end|>") {
            s = s.trim_end_matches("<|im_end|>").trim();
        }

        let re = regex::Regex::new(r"(?s)^<think>.*?</think>\s*").unwrap();
        re.replace(s, "").to_string()
    } else if *template_ty == PromptTemplateType::GptOss {
        // GPT-OSS: extract the payload of the final channel message
        let s = output.as_ref().trim();

        let re =
            regex::Regex::new(r"(?s).*<\|channel\|>final<\|message\|>(.*?)<\|return\|>$").unwrap();

        if let Some(caps) = re.captures(s) {
            let extracted = &caps[1];
            extracted.to_owned()
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::Qwen3Agent {
        // Qwen3 agent: strip markers and close an unterminated <final_answer>
        let mut s = output.as_ref().trim();

        if s.starts_with(":") {
            s = s.trim_start_matches(":").trim();
        }

        if s.starts_with("</think>") {
            s = s.trim_start_matches("</think>").trim();
        }

        if s.ends_with("<|im_end|>") {
            s = s.trim_end_matches("<|im_end|>").trim();
        }

        if s.contains("<final_answer>") && !s.contains("</final_answer>") {
            format!("{s}</final_answer>")
        } else {
            s.to_owned()
        }
    } else if *template_ty == PromptTemplateType::SeedOssNoThink {
        // Seed-OSS: extract the text between </seed:think> and <seed:eos>
        let s = output.as_ref().trim();

        let re = regex::Regex::new(r"(?s)</seed:think>\s*(.*?)\s*<seed:eos>").unwrap();

        if let Some(caps) = re.captures(s) {
            let extracted = &caps[1];
            extracted.to_owned()
        } else {
            s.to_owned()
        }
    } else {
        // unknown template: plain whitespace trim only
        output.as_ref().trim().to_owned()
    };

    Ok(output)
}
2762
2763/// Build the chat prompt from the chat messages.
2764///
2765/// # Arguments
2766///
2767/// * `model_name`: The name of the model.
2768///
2769/// * `chat_request`: The chat request.
2770///
2771/// # Returns
2772///
2773/// A tuple containing the prompt, the number of available tokens for completions, and a boolean indicating whether tools are used.
2774fn build_prompt(
2775 model_name: Option<&String>,
2776 chat_request: &mut RequestOfModelResponse,
2777) -> Result<(String, u64, bool), LlamaCoreError> {
2778 let metadata = get_model_metadata(model_name)?;
2779 let ctx_size = metadata.ctx_size as u64;
2780 let chat_prompt = ChatPrompt::from(metadata.prompt_template);
2781
2782 // compute max prompt tokens, which is 80% of the context size
2783 let max_prompt_tokens = ctx_size * 4 / 5;
2784
2785 // convert chat_request.input to a user message and append it to chat_request.messages
2786 let mut chat_completions_messages = to_chat_messages(chat_request.input.as_ref().unwrap())?;
2787
2788 #[cfg(feature = "logging")]
2789 debug!(target: "stdout", "converted chat messages: {chat_completions_messages:?}");
2790
2791 loop {
2792 // if chat_request.messages.is_empty() {
2793 if chat_request.input.is_none() {
2794 let err_msg = "The `input` field of the request is empty.";
2795
2796 #[cfg(feature = "logging")]
2797 error!(target: "stdout", "{err_msg}");
2798
2799 return Err(LlamaCoreError::Operation(err_msg.to_owned()));
2800 }
2801
2802 let (prompt, tool_use) = match &chat_request.tool_choice {
2803 ToolChoice::None => {
2804 match chat_prompt.build_with_tools(&mut chat_completions_messages, Some(&[])) {
2805 Ok(prompt) => (prompt, false),
2806 Err(e) => {
2807 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2808
2809 #[cfg(feature = "logging")]
2810 error!(target: "stdout", "{}", &err_msg);
2811
2812 return Err(LlamaCoreError::Operation(err_msg));
2813 }
2814 }
2815 }
2816 _ => todo!("Other tool choices are not supported yet."),
2817 };
2818 #[cfg(feature = "logging")]
2819 info!(target: "stdout", "Try to set prompt: {prompt}");
2820
2821 // set prompt
2822 set_prompt(model_name, &prompt)?;
2823
2824 // Retrieve the number of prompt tokens.
2825 let token_info = get_token_info_by_graph_name(model_name)?;
2826
2827 match token_info.prompt_tokens > max_prompt_tokens {
2828 true => {
2829 match chat_completions_messages[0].role() {
2830 ChatCompletionRole::System => {
2831 // corner case: context size is too small, `system -> user -> assistant -> tool` cannot be trimmed.
2832 if chat_completions_messages.len() == 4
2833 && chat_completions_messages[1].role() == ChatCompletionRole::User
2834 && chat_completions_messages[2].role() == ChatCompletionRole::Assistant
2835 && chat_completions_messages[3].role() == ChatCompletionRole::Tool
2836 {
2837 let err_msg = format!(
2838 "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
2839 token_info.prompt_tokens, max_prompt_tokens
2840 );
2841
2842 #[cfg(feature = "logging")]
2843 error!(target: "stdout", "{}", &err_msg);
2844
2845 return Err(LlamaCoreError::Operation(err_msg));
2846 }
2847
2848 if chat_completions_messages.len() > 2 {
2849 #[cfg(feature = "logging")]
2850 info!(target: "stdout", "Prune chat history: current length {}", chat_completions_messages.len());
2851
2852 // remove user_1 if it exists
2853 // For example, `system -> user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> ... -> user_2 -> ... -> user_latest`
2854 if chat_completions_messages[1].role() == ChatCompletionRole::User {
2855 let user_message = chat_completions_messages.remove(1);
2856
2857 #[cfg(feature = "logging")]
2858 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2859 }
2860
2861 // remove all messages until the message is of `user`
2862 // For example, `system -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> user_2 -> ... -> user_latest`
2863 while chat_completions_messages[1].role() != ChatCompletionRole::User {
2864 let message = chat_completions_messages.remove(1);
2865
2866 #[cfg(feature = "logging")]
2867 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2868
2869 if chat_completions_messages.len() == 1 {
2870 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2871
2872 #[cfg(feature = "logging")]
2873 error!(target: "stdout", "{err_msg}");
2874
2875 return Err(LlamaCoreError::Operation(err_msg));
2876 }
2877 }
2878 } else if token_info.prompt_tokens > ctx_size {
2879 let err_msg = format!(
2880 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2881 token_info.prompt_tokens, ctx_size
2882 );
2883
2884 #[cfg(feature = "logging")]
2885 error!(target: "stdout", "{}", &err_msg);
2886
2887 return Err(LlamaCoreError::Operation(err_msg));
2888 } else {
2889 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2890 }
2891 }
2892 ChatCompletionRole::User => {
2893 // corner case: context size is too small, `user -> assistant -> tool` cannot be trimmed.
2894 if chat_completions_messages.len() == 3
2895 && chat_completions_messages[1].role() == ChatCompletionRole::User
2896 && chat_completions_messages[2].role() == ChatCompletionRole::Assistant
2897 && chat_completions_messages[3].role() == ChatCompletionRole::Tool
2898 {
2899 let err_msg = format!(
2900 "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
2901 token_info.prompt_tokens, max_prompt_tokens
2902 );
2903
2904 #[cfg(feature = "logging")]
2905 error!(target: "stdout", "{}", &err_msg);
2906
2907 return Err(LlamaCoreError::Operation(err_msg));
2908 }
2909
2910 if chat_completions_messages.len() > 1 {
2911 // user_1 -> ... -> user_2 -> ... -> user_latest
2912
2913 // remove user_1 if it exists
2914 // For example, `user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `... -> user_2 -> ... -> user_latest`
2915 if chat_completions_messages[0].role() == ChatCompletionRole::User {
2916 let user_message = chat_completions_messages.remove(0);
2917
2918 #[cfg(feature = "logging")]
2919 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2920 }
2921
2922 // remove all messages until the message is of `user`
2923 // For example, `... -> user_2 -> ... -> user_latest` will be converted to `user_2 -> ... -> user_latest`
2924 while chat_completions_messages[0].role() != ChatCompletionRole::User {
2925 let message = chat_completions_messages.remove(0);
2926
2927 #[cfg(feature = "logging")]
2928 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2929
2930 if chat_completions_messages.is_empty() {
2931 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2932
2933 #[cfg(feature = "logging")]
2934 error!(target: "stdout", "{err_msg}");
2935
2936 return Err(LlamaCoreError::Operation(err_msg));
2937 }
2938 }
2939 } else if token_info.prompt_tokens > ctx_size {
2940 let err_msg = format!(
2941 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2942 token_info.prompt_tokens, ctx_size
2943 );
2944
2945 #[cfg(feature = "logging")]
2946 error!(target: "stdout", "{}", &err_msg);
2947
2948 return Err(LlamaCoreError::Operation(err_msg));
2949 } else {
2950 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2951 }
2952 }
2953 _ => {
2954 #[cfg(feature = "logging")]
2955 info!(target: "stdout", "remove a {} message from the message queue", chat_completions_messages[0].role());
2956
2957 chat_completions_messages.remove(0);
2958 }
2959 }
2960
2961 continue;
2962 }
2963 false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
2964 }
2965 }
2966}
2967
2968fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
2969 let chat_graphs = match CHAT_GRAPHS.get() {
2970 Some(chat_graphs) => chat_graphs,
2971 None => {
2972 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2973
2974 #[cfg(feature = "logging")]
2975 error!(target: "stdout", "{}", &err_msg);
2976
2977 return Err(LlamaCoreError::Operation(err_msg.into()));
2978 }
2979 };
2980
2981 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2982 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
2983
2984 #[cfg(feature = "logging")]
2985 error!(target: "stdout", "{}", &err_msg);
2986
2987 LlamaCoreError::Operation(err_msg)
2988 })?;
2989
2990 match model_name {
2991 Some(model_name) => {
2992 #[cfg(feature = "logging")]
2993 info!(target: "stdout", "Set prompt to the chat model named {model_name}");
2994
2995 match chat_graphs.contains_key(model_name) {
2996 true => {
2997 let graph = chat_graphs.get_mut(model_name).unwrap();
2998 let tensor_data = prompt.as_ref().as_bytes().to_vec();
2999 set_tensor_data_u8(graph, 0, &tensor_data)
3000 }
3001 false => match chat_graphs.iter_mut().next() {
3002 Some((_, graph)) => {
3003 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3004 set_tensor_data_u8(graph, 0, &tensor_data)
3005 }
3006 None => {
3007 let err_msg = "There is no model available in the chat graphs.";
3008
3009 #[cfg(feature = "logging")]
3010 error!(target: "stdout", "{}", &err_msg);
3011
3012 Err(LlamaCoreError::Operation(err_msg.into()))
3013 }
3014 },
3015 }
3016 }
3017 None => {
3018 #[cfg(feature = "logging")]
3019 info!(target: "stdout", "Set prompt to the default chat model.");
3020
3021 match chat_graphs.iter_mut().next() {
3022 Some((_, graph)) => {
3023 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3024 set_tensor_data_u8(graph, 0, &tensor_data)
3025 }
3026 None => {
3027 let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
3028
3029 #[cfg(feature = "logging")]
3030 error!(target: "stdout", "{err_msg}");
3031
3032 Err(LlamaCoreError::Operation(err_msg.into()))
3033 }
3034 }
3035 }
3036 }
3037}
3038
3039/// Get a copy of the metadata of the model.
3040fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
3041 let chat_graphs = match CHAT_GRAPHS.get() {
3042 Some(chat_graphs) => chat_graphs,
3043 None => {
3044 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3045
3046 #[cfg(feature = "logging")]
3047 error!(target: "stdout", "{err_msg}");
3048
3049 return Err(LlamaCoreError::Operation(err_msg.into()));
3050 }
3051 };
3052
3053 let chat_graphs = chat_graphs.lock().map_err(|e| {
3054 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3055
3056 #[cfg(feature = "logging")]
3057 error!(target: "stdout", "{}", &err_msg);
3058
3059 LlamaCoreError::Operation(err_msg)
3060 })?;
3061
3062 match model_name {
3063 Some(model_name) => match chat_graphs.contains_key(model_name) {
3064 true => {
3065 let graph = chat_graphs.get(model_name).unwrap();
3066 Ok(graph.metadata.clone())
3067 }
3068 false => match chat_graphs.iter().next() {
3069 Some((_, graph)) => Ok(graph.metadata.clone()),
3070 None => {
3071 let err_msg = "There is no model available in the chat graphs.";
3072
3073 #[cfg(feature = "logging")]
3074 error!(target: "stdout", "{}", &err_msg);
3075
3076 Err(LlamaCoreError::Operation(err_msg.into()))
3077 }
3078 },
3079 },
3080 None => match chat_graphs.iter().next() {
3081 Some((_, graph)) => Ok(graph.metadata.clone()),
3082 None => {
3083 let err_msg = "There is no model available in the chat graphs.";
3084
3085 #[cfg(feature = "logging")]
3086 error!(target: "stdout", "{err_msg}");
3087
3088 Err(LlamaCoreError::Operation(err_msg.into()))
3089 }
3090 },
3091 }
3092}
3093
3094fn update_model_metadata(
3095 model_name: Option<&String>,
3096 metadata: &GgmlMetadata,
3097) -> Result<(), LlamaCoreError> {
3098 let config = match serde_json::to_string(metadata) {
3099 Ok(config) => config,
3100 Err(e) => {
3101 let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
3102
3103 #[cfg(feature = "logging")]
3104 error!(target: "stdout", "{}", &err_msg);
3105
3106 return Err(LlamaCoreError::Operation(err_msg));
3107 }
3108 };
3109
3110 let chat_graphs = match CHAT_GRAPHS.get() {
3111 Some(chat_graphs) => chat_graphs,
3112 None => {
3113 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3114
3115 #[cfg(feature = "logging")]
3116 error!(target: "stdout", "{err_msg}");
3117
3118 return Err(LlamaCoreError::Operation(err_msg.into()));
3119 }
3120 };
3121
3122 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3123 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
3124
3125 #[cfg(feature = "logging")]
3126 error!(target: "stdout", "{}", &err_msg);
3127
3128 LlamaCoreError::Operation(err_msg)
3129 })?;
3130
3131 match model_name {
3132 Some(model_name) => {
3133 match chat_graphs.contains_key(model_name) {
3134 true => {
3135 let graph = chat_graphs.get_mut(model_name).unwrap();
3136 // update metadata
3137 set_tensor_data_u8(graph, 1, config.as_bytes())
3138 }
3139 false => match chat_graphs.iter_mut().next() {
3140 Some((_, graph)) => {
3141 // update metadata
3142 set_tensor_data_u8(graph, 1, config.as_bytes())
3143 }
3144 None => {
3145 let err_msg = "There is no model available in the chat graphs.";
3146
3147 #[cfg(feature = "logging")]
3148 error!(target: "stdout", "{}", &err_msg);
3149
3150 Err(LlamaCoreError::Operation(err_msg.into()))
3151 }
3152 },
3153 }
3154 }
3155 None => {
3156 match chat_graphs.iter_mut().next() {
3157 Some((_, graph)) => {
3158 // update metadata
3159 set_tensor_data_u8(graph, 1, config.as_bytes())
3160 }
3161 None => {
3162 let err_msg = "There is no model available in the chat graphs.";
3163
3164 #[cfg(feature = "logging")]
3165 error!(target: "stdout", "{err_msg}");
3166
3167 Err(LlamaCoreError::Operation(err_msg.into()))
3168 }
3169 }
3170 }
3171 }
3172}
3173
3174fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
3175 // get metadata
3176 let metadata = get_model_metadata(model_name)?;
3177
3178 // update model with the original metadata
3179 update_model_metadata(model_name, &metadata)
3180}
3181
/// Drain sequence after the backend reports a "context full" condition.
///
/// NOTE(review): the intended emission order appears to be
/// `Message -> Usage -> Done -> EndOfSequence` (per the commented-out
/// `compute_stream` body below) — confirm when that code path is re-enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
enum ContextFullState {
    // Next emission: the final "context full" message chunk.
    Message,
    // Next emission: the usage (token-count) chunk.
    Usage,
    // Next emission: the terminating `data: [DONE]` chunk.
    Done,
    // Stream exhausted; only the end-of-sequence sentinel remains.
    EndOfSequence,
}
3190
/// Normal (non-error) streaming progression.
///
/// NOTE(review): semantics inferred from the commented-out `compute_stream`
/// body below — confirm when that code path is re-enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
enum StreamState {
    // Streaming tokens; emit a usage chunk when the sequence ends.
    Usage,
    // Streaming tokens; skip the usage chunk at end of sequence.
    NoUsage,
    // Next emission: the terminating `data: [DONE]` chunk.
    Done,
    // Stream exhausted; only the end-of-sequence sentinel remains.
    EndOfSequence,
}
3199
/// Drain sequence after the backend reports a "prompt too long" condition.
///
/// NOTE(review): mirrors `ContextFullState`; semantics inferred from the
/// commented-out `compute_stream` body below — confirm when re-enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
enum PromptTooLongState {
    // Next emission: the final error message chunk.
    Message,
    // Next emission: the usage (token-count) chunk.
    Usage,
    // Next emission: the terminating `data: [DONE]` chunk.
    Done,
    // Stream exhausted; only the end-of-sequence sentinel remains.
    EndOfSequence,
}
3208
/// A pull-based stream of serialized chat-completion chunks.
///
/// Only one `ChatStream` may compute at a time: instances race for the global
/// `CHAT_STREAM_ACTIVE` flag and, while waiting, park their task waker in
/// `CHAT_STREAM_WAKER_QUEUE` (see `poll_next` and `Drop`).
struct ChatStream {
    // id: String,
    /// Target model; `None` selects the first available graph.
    model: Option<String>,
    // include_usage: bool,
    /// Progress through the "context full" drain sequence.
    context_full_state: ContextFullState,
    /// Progress through the "prompt too long" drain sequence.
    prompt_too_long_state: PromptTooLongState,
    /// Progress through the normal streaming drain sequence.
    stream_state: StreamState,
    /// Pre-computed chunks to replay; `None` means compute chunks live.
    cache: Option<VecDeque<String>>,
    /// True while parked waiting for the global stream lock.
    is_waiting: bool,
    /// True when this instance currently holds the global stream lock.
    has_lock: bool,
}
3220impl ChatStream {
3221 fn new(
3222 model: Option<String>,
3223 // id: String,
3224 // include_usage: bool,
3225 cache: Option<Vec<String>>,
3226 ) -> Self {
3227 // Try to acquire lock
3228 let has_lock = CHAT_STREAM_ACTIVE
3229 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3230 .is_ok();
3231
3232 #[cfg(feature = "logging")]
3233 if !has_lock {
3234 info!(target: "stdout", "Lock acquisition failed in ChatStream::new, creating with waiting status");
3235 }
3236
3237 ChatStream {
3238 // id,
3239 model,
3240 // include_usage,
3241 context_full_state: ContextFullState::Message,
3242 prompt_too_long_state: PromptTooLongState::Message,
3243 stream_state: StreamState::Usage,
3244 cache: cache.map(VecDeque::from),
3245 is_waiting: !has_lock,
3246 has_lock,
3247 }
3248 }
3249
3250 // Try to acquire lock, returns whether successful
3251 fn try_acquire_lock(&mut self) -> bool {
3252 if self.has_lock {
3253 return true;
3254 }
3255
3256 let acquired = CHAT_STREAM_ACTIVE
3257 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3258 .is_ok();
3259
3260 if acquired {
3261 self.has_lock = true;
3262 self.is_waiting = false;
3263 }
3264
3265 acquired
3266 }
3267}
3268impl Drop for ChatStream {
3269 fn drop(&mut self) {
3270 // Clean up is only needed if we have the lock or if stream was actually used
3271 if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3272 #[cfg(feature = "logging")]
3273 info!(target: "stdout", "Cleaning up context for ChatStream");
3274
3275 match &self.model {
3276 Some(model_name) => {
3277 match CHAT_GRAPHS.get() {
3278 Some(chat_graphs) => {
3279 match chat_graphs.lock() {
3280 Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3281 true => {
3282 let graph = chat_graphs.get_mut(model_name).unwrap();
3283
3284 // clean up the context
3285 if let Err(e) = graph.finish_single() {
3286 let err_msg = format!(
3287 "Failed to clean up the context. Reason: {e}"
3288 );
3289
3290 #[cfg(feature = "logging")]
3291 error!(target: "stdout", "{}", &err_msg);
3292
3293 #[cfg(not(feature = "logging"))]
3294 println!(
3295 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3296 &err_msg
3297 );
3298 }
3299 }
3300 false => match chat_graphs.iter_mut().next() {
3301 Some((_, graph)) => {
3302 // clean up the context
3303 if let Err(e) = graph.finish_single() {
3304 let err_msg = format!(
3305 "Failed to clean up the context. Reason: {e}"
3306 );
3307
3308 #[cfg(feature = "logging")]
3309 error!(target: "stdout", "{}", &err_msg);
3310
3311 #[cfg(not(feature = "logging"))]
3312 println!(
3313 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3314 &err_msg
3315 );
3316 }
3317 }
3318 None => {
3319 let err_msg =
3320 "There is no model available in the chat graphs.";
3321
3322 #[cfg(feature = "logging")]
3323 error!(target: "stdout", "{}", &err_msg);
3324
3325 #[cfg(not(feature = "logging"))]
3326 println!(
3327 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3328 &err_msg
3329 );
3330 }
3331 },
3332 },
3333 Err(e) => {
3334 let err_msg =
3335 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3336
3337 #[cfg(feature = "logging")]
3338 error!(target: "stdout", "{}", &err_msg);
3339
3340 #[cfg(not(feature = "logging"))]
3341 println!(
3342 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3343 &err_msg
3344 );
3345 }
3346 }
3347 }
3348 None => {
3349 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3350
3351 #[cfg(feature = "logging")]
3352 error!(target: "stdout", "{}", &err_msg);
3353
3354 #[cfg(not(feature = "logging"))]
3355 println!(
3356 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3357 &err_msg
3358 );
3359 }
3360 };
3361 }
3362 None => {
3363 match CHAT_GRAPHS.get() {
3364 Some(chat_graphs) => {
3365 match chat_graphs.lock() {
3366 Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3367 Some((_, graph)) => {
3368 // clean up the context
3369 if let Err(e) = graph.finish_single() {
3370 let err_msg = format!(
3371 "Failed to clean up the context. Reason: {e}"
3372 );
3373
3374 #[cfg(feature = "logging")]
3375 error!(target: "stdout", "{}", &err_msg);
3376
3377 #[cfg(not(feature = "logging"))]
3378 println!(
3379 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3380 &err_msg
3381 );
3382 }
3383 }
3384 None => {
3385 let err_msg =
3386 "There is no model available in the chat graphs.";
3387
3388 #[cfg(feature = "logging")]
3389 error!(target: "stdout", "{err_msg}");
3390
3391 #[cfg(not(feature = "logging"))]
3392 println!(
3393 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3394 err_msg
3395 );
3396 }
3397 },
3398 Err(e) => {
3399 let err_msg =
3400 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3401
3402 #[cfg(feature = "logging")]
3403 error!(target: "stdout", "{}", &err_msg);
3404
3405 #[cfg(not(feature = "logging"))]
3406 println!(
3407 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3408 &err_msg
3409 );
3410 }
3411 }
3412 }
3413 None => {
3414 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3415
3416 #[cfg(feature = "logging")]
3417 error!(target: "stdout", "{}", &err_msg);
3418
3419 #[cfg(not(feature = "logging"))]
3420 println!(
3421 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3422 &err_msg
3423 );
3424 }
3425 };
3426 }
3427 }
3428
3429 #[cfg(feature = "logging")]
3430 info!(target: "stdout", "Model context cleanup done!");
3431 }
3432
3433 // reset the model metadata
3434 if let Err(e) = reset_model_metadata(self.model.as_ref()) {
3435 let err_msg = format!("Fail to reset model metadata. Reason: {e}");
3436
3437 #[cfg(feature = "logging")]
3438 error!(target: "stdout", "{}", &err_msg);
3439
3440 #[cfg(not(feature = "logging"))]
3441 println!("[ERROR][llama_core] {}", &err_msg);
3442 }
3443 #[cfg(feature = "logging")]
3444 info!(target: "stdout", "Model metadata reset done!");
3445
3446 // When dropping a ChatStream that held the lock, check if there are waiting streams
3447 if self.has_lock {
3448 // Reset the atomic flag
3449 CHAT_STREAM_ACTIVE.store(false, Ordering::SeqCst);
3450
3451 #[cfg(feature = "logging")]
3452 info!(target: "stdout", "Lock from ChatStream released");
3453
3454 // Wake up waiting streams
3455 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3456 if let Some(waker) = queue.pop_front() {
3457 #[cfg(feature = "logging")]
3458 info!(target: "stdout", "Waking up a waiting ChatStream");
3459
3460 waker.wake();
3461 }
3462 }
3463 }
3464 }
3465}
impl futures::Stream for ChatStream {
    type Item = Result<String, LlamaCoreError>;

    /// Yields the next serialized chunk, or parks the task while another
    /// `ChatStream` holds the global single-stream lock.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // If this is a waiting stream, try to acquire the lock
        if this.is_waiting {
            if !this.try_acquire_lock() {
                // Store the waker to be notified when the lock becomes available
                if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
                    // Remove any previous instance of this waker
                    queue.retain(|w| !w.will_wake(cx.waker()));
                    // Add this waker to the queue
                    queue.push_back(cx.waker().clone());

                    #[cfg(feature = "logging")]
                    debug!(target: "stdout", "ChatStream is waiting for lock, added waker to queue");
                }

                // NOTE(review): if the lock holder drops between the failed
                // `try_acquire_lock` above and the waker registration, its
                // `Drop` pops from an empty queue and this task's wakeup is
                // lost. A re-check of the lock after registering the waker
                // would close that window — confirm and address separately.
                return Poll::Pending;
            }

            #[cfg(feature = "logging")]
            info!(target: "stdout", "ChatStream acquired lock and is now active");
            // If we got here, we successfully acquired the lock and can proceed
        }

        // Ensure we still have the lock
        if !this.has_lock && !this.try_acquire_lock() {
            // Lost the lock, need to wait
            this.is_waiting = true;

            // Register waker to be notified when lock is available
            if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
                queue.retain(|w| !w.will_wake(cx.waker()));
                queue.push_back(cx.waker().clone());
            }

            return Poll::Pending;
        }

        // No cache: compute the next chunk live from the graph.
        if this.cache.is_none() {
            let res = compute_stream(
                this.model.clone(),
                // this.id.clone(),
                // this.include_usage,
                &mut this.prompt_too_long_state,
                &mut this.context_full_state,
                &mut this.stream_state,
            );

            match res {
                Ok(x) => {
                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "next item for ChatStream: {}", &x);

                    // The sentinel string or an empty chunk terminates the stream.
                    if x != "[GGML] End of sequence" && !x.is_empty() {
                        Poll::Ready(Some(Ok(x)))
                    } else {
                        // stopped
                        Poll::Ready(None)
                    }
                }
                Err(e) => Poll::Ready(Some(Err(e))),
            }
        } else {
            // Cached mode: replay pre-computed chunks until the cache is empty.
            let x = this.cache.as_mut().unwrap().pop_front();

            #[cfg(feature = "logging")]
            info!(target: "stdout", "Get the next item from the cache for ChatStream: {:?}", &x);

            match x {
                Some(x) => Poll::Ready(Some(Ok(x))),
                None => Poll::Ready(None),
            }
        }
    }
}
3545
3546/// Helper function to get or initialize the waker queue for waiting ChatStreams
3547fn get_chat_stream_waker_queue() -> &'static Mutex<VecDeque<Waker>> {
3548 CHAT_STREAM_WAKER_QUEUE.get_or_init(|| {
3549 #[cfg(feature = "logging")]
3550 info!(target: "stdout", "Initializing ChatStream waker queue");
3551 Mutex::new(VecDeque::new())
3552 })
3553}
3554
3555#[allow(unused_variables)]
3556fn compute_stream(
3557 model_name: Option<String>,
3558 // id: String,
3559 // include_usage: bool,
3560 prompt_too_long_state: &mut PromptTooLongState,
3561 context_full_state: &mut ContextFullState,
3562 stream_state: &mut StreamState,
3563) -> Result<String, LlamaCoreError> {
3564 {
3565 // #[cfg(feature = "logging")]
3566 // info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
3567
3568 // #[cfg(feature = "logging")]
3569 // debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3570 // #[cfg(feature = "logging")]
3571 // debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3572 // #[cfg(feature = "logging")]
3573 // debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3574
3575 // if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3576 // || *context_full_state == ContextFullState::EndOfSequence
3577 // || *stream_state == StreamState::EndOfSequence
3578 // {
3579 // #[cfg(feature = "logging")]
3580 // info!(target: "stdout", "Return the chat stream chunk!");
3581
3582 // return Ok("[GGML] End of sequence".to_string());
3583 // }
3584
3585 // let chat_graphs = match CHAT_GRAPHS.get() {
3586 // Some(chat_graphs) => chat_graphs,
3587 // None => {
3588 // let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3589
3590 // #[cfg(feature = "logging")]
3591 // error!(target: "stdout", "{}", &err_msg);
3592
3593 // return Err(LlamaCoreError::Operation(err_msg.into()));
3594 // }
3595 // };
3596
3597 // // We're already holding the ChatStream lock, so we know we have exclusive access to the graph
3598 // let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3599 // let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3600
3601 // #[cfg(feature = "logging")]
3602 // error!(target: "stdout", "{}", &err_msg);
3603
3604 // LlamaCoreError::Operation(err_msg)
3605 // })?;
3606
3607 // // Get the graph based on model name
3608 // let res = match &model_name {
3609 // Some(model_name) => {
3610 // match chat_graphs.contains_key(model_name) {
3611 // true => {
3612 // let graph = chat_graphs.get_mut(model_name).unwrap();
3613 // // compute
3614 // match graph.compute_single() {
3615 // Ok(_) => {
3616 // #[cfg(feature = "logging")]
3617 // debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3618
3619 // // Process according to state
3620 // match stream_state {
3621 // StreamState::Usage | StreamState::NoUsage => {
3622 // // Retrieve the output
3623 // let output_buffer =
3624 // get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3625
3626 // #[cfg(feature = "logging")]
3627 // info!(target: "stdout", "retrieved the output buffer");
3628
3629 // // decode the output buffer to a utf8 string
3630 // let output = match String::from_utf8(output_buffer.clone()) {
3631 // Ok(token) => token,
3632 // Err(_) => {
3633 // let mutex = CACHED_UTF8_ENCODINGS
3634 // .get_or_init(|| Mutex::new(Vec::new()));
3635 // let mut cached_encodings = mutex.lock().map_err(|e| {
3636 // let err_msg = format!(
3637 // "Fail to acquire the lock of `UTF8_ENCODINGS`. Reason: {e}"
3638 // );
3639
3640 // #[cfg(feature = "logging")]
3641 // error!(target: "stdout", "{}", &err_msg);
3642
3643 // LlamaCoreError::Operation(err_msg)
3644 // })?;
3645
3646 // // cache the bytes for future decoding
3647 // cached_encodings.extend_from_slice(&output_buffer[..]);
3648
3649 // match String::from_utf8(cached_encodings.to_vec()) {
3650 // Ok(token) => {
3651 // // clear CACHED_UTF8_ENCODINGS
3652 // cached_encodings.clear();
3653
3654 // token
3655 // }
3656 // Err(e) => {
3657 // // TODO This is a temp check. In case, infinite cached encodings happen.
3658 // if cached_encodings.len() > 4 {
3659 // let err_msg = format!("Fail to convert a vector of bytes to string. The length of the utf8 bytes exceeds 4. {e}");
3660
3661 // #[cfg(feature = "logging")]
3662 // error!(target: "stdout", "{}", &err_msg);
3663
3664 // #[cfg(feature = "logging")]
3665 // error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3666
3667 // // let token = String::from_utf8_lossy(
3668 // // &cached_encodings,
3669 // // )
3670 // // .to_string();
3671
3672 // // clear CACHED_UTF8_ENCODINGS
3673 // cached_encodings.clear();
3674
3675 // String::from("")
3676 // } else {
3677 // let warn_msg = format!("Fail to convert a vector of bytes to string. {e}");
3678
3679 // #[cfg(feature = "logging")]
3680 // warn!(target: "stdout", "{}", &warn_msg);
3681
3682 // String::from("")
3683 // }
3684 // }
3685 // }
3686 // }
3687 // };
3688
3689 // #[cfg(feature = "logging")]
3690 // info!(target: "stdout", "decoded the output buffer");
3691
3692 // let created = SystemTime::now()
3693 // .duration_since(std::time::UNIX_EPOCH)
3694 // .map_err(|e| {
3695 // let err_msg = format!(
3696 // "Failed to get the current time. Reason: {e}"
3697 // );
3698
3699 // #[cfg(feature = "logging")]
3700 // error!(target: "stdout", "{}", &err_msg);
3701
3702 // LlamaCoreError::Operation(err_msg)
3703 // })?;
3704
3705 // let chat_completion_chunk = ChatCompletionChunk {
3706 // id,
3707 // object: "chat.completion.chunk".to_string(),
3708 // created: created.as_secs(),
3709 // model: graph.name().to_owned(),
3710 // system_fingerprint: "fp_44709d6fcb".to_string(),
3711 // choices: vec![ChatCompletionChunkChoice {
3712 // index: 0,
3713 // delta: ChatCompletionChunkChoiceDelta {
3714 // role: ChatCompletionRole::Assistant,
3715 // content: Some(output),
3716 // tool_calls: vec![],
3717 // },
3718 // logprobs: None,
3719 // finish_reason: None,
3720 // }],
3721 // usage: None,
3722 // };
3723
3724 // #[cfg(feature = "logging")]
3725 // info!(target: "stdout", "created chat completion chunk");
3726
3727 // // serialize chat completion chunk
3728 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
3729 // .map_err(|e| {
3730 // let err_msg = format!(
3731 // "Failed to serialize chat completion chunk. Reason: {e}"
3732 // );
3733
3734 // #[cfg(feature = "logging")]
3735 // error!(target: "stdout", "{}", &err_msg);
3736
3737 // LlamaCoreError::Operation(err_msg)
3738 // })?;
3739
3740 // Ok(format!("data: {chunk_str}\n\n"))
3741 // }
3742 // StreamState::Done => {
3743 // *stream_state = StreamState::EndOfSequence;
3744
3745 // Ok("data: [DONE]\n\n".to_string())
3746 // }
3747 // StreamState::EndOfSequence => {
3748 // Ok("[GGML] End of sequence".to_string())
3749 // }
3750 // }
3751 // }
3752 // Err(wasmedge_wasi_nn::Error::BackendError(
3753 // wasmedge_wasi_nn::BackendError::EndOfSequence,
3754 // )) => {
3755 // #[cfg(feature = "logging")]
3756 // debug!(target: "stdout", "End of sequence");
3757
3758 // match stream_state {
3759 // StreamState::Usage => {
3760 // *stream_state = StreamState::Done;
3761
3762 // // retrieve the number of prompt and completion tokens
3763 // let token_info = get_token_info_by_graph(graph)?;
3764
3765 // let usage = Some(Usage {
3766 // prompt_tokens: token_info.prompt_tokens,
3767 // completion_tokens: token_info.completion_tokens,
3768 // total_tokens: token_info.prompt_tokens
3769 // + token_info.completion_tokens,
3770 // });
3771
3772 // #[cfg(feature = "logging")]
3773 // info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3774
3775 // let created = SystemTime::now()
3776 // .duration_since(std::time::UNIX_EPOCH)
3777 // .map_err(|e| {
3778 // let err_msg = format!(
3779 // "Failed to get the current time. Reason: {e}"
3780 // );
3781
3782 // #[cfg(feature = "logging")]
3783 // error!(target: "stdout", "{}", &err_msg);
3784
3785 // LlamaCoreError::Operation(err_msg)
3786 // })?;
3787
3788 // let chat_completion_chunk = ChatCompletionChunk {
3789 // id,
3790 // object: "chat.completion.chunk".to_string(),
3791 // created: created.as_secs(),
3792 // model: graph.name().to_owned(),
3793 // system_fingerprint: "fp_44709d6fcb".to_string(),
3794 // choices: vec![],
3795 // usage,
3796 // };
3797
3798 // // serialize chat completion chunk
3799 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
3800 // .map_err(|e| {
3801 // let err_msg = format!(
3802 // "Failed to serialize chat completion chunk. Reason: {e}"
3803 // );
3804
3805 // #[cfg(feature = "logging")]
3806 // error!(target: "stdout", "{}", &err_msg);
3807
3808 // LlamaCoreError::Operation(err_msg)
3809 // })?;
3810
3811 // Ok(format!("data: {chunk_str}\n\n"))
3812 // }
3813 // StreamState::Done | StreamState::NoUsage => {
3814 // *stream_state = StreamState::EndOfSequence;
3815
3816 // Ok("data: [DONE]\n\n".to_string())
3817 // }
3818 // StreamState::EndOfSequence => {
3819 // Ok("[GGML] End of sequence".to_string())
3820 // }
3821 // }
3822 // }
3823 // Err(wasmedge_wasi_nn::Error::BackendError(
3824 // wasmedge_wasi_nn::BackendError::ContextFull,
3825 // )) => {
3826 // #[cfg(feature = "logging")]
3827 // debug!(target: "stdout", "Context full");
3828
3829 // match context_full_state {
3830 // ContextFullState::Message => {
3831 // match include_usage {
3832 // true => *context_full_state = ContextFullState::Usage,
3833 // false => *context_full_state = ContextFullState::Done,
3834 // }
3835
3836 // let created = SystemTime::now()
3837 // .duration_since(std::time::UNIX_EPOCH)
3838 // .map_err(|e| {
3839 // let err_msg = format!(
3840 // "Failed to get the current time. Reason: {e}"
3841 // );
3842
3843 // #[cfg(feature = "logging")]
3844 // error!(target: "stdout", "{}", &err_msg);
3845
3846 // LlamaCoreError::Operation(err_msg)
3847 // })?;
3848
3849 // let chat_completion_chunk = ChatCompletionChunk {
3850 // id,
3851 // object: "chat.completion.chunk".to_string(),
3852 // created: created.as_secs(),
3853 // model: graph.name().to_owned(),
3854 // system_fingerprint: "fp_44709d6fcb".to_string(),
3855 // choices: vec![ChatCompletionChunkChoice {
3856 // index: 0,
3857 // delta: ChatCompletionChunkChoiceDelta {
3858 // role: ChatCompletionRole::Assistant,
3859 // content: Some(
3860 // "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
3861 // ),
3862 // tool_calls: vec![],
3863 // },
3864 // logprobs: None,
3865 // finish_reason: Some(FinishReason::length),
3866 // }],
3867 // usage: None,
3868 // };
3869
3870 // // serialize chat completion chunk
3871 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
3872 // .map_err(|e| {
3873 // let err_msg = format!(
3874 // "Failed to serialize chat completion chunk. Reason: {e}"
3875 // );
3876
3877 // #[cfg(feature = "logging")]
3878 // error!(target: "stdout", "{}", &err_msg);
3879
3880 // LlamaCoreError::Operation(err_msg)
3881 // })?;
3882
3883 // Ok(format!("data: {chunk_str}\n\n"))
3884 // }
3885 // ContextFullState::Usage => {
3886 // *context_full_state = ContextFullState::Done;
3887
3888 // // retrieve the number of prompt and completion tokens
3889 // let token_info = get_token_info_by_graph(graph)?;
3890
3891 // let usage = Some(Usage {
3892 // prompt_tokens: token_info.prompt_tokens,
3893 // completion_tokens: token_info.completion_tokens,
3894 // total_tokens: token_info.prompt_tokens
3895 // + token_info.completion_tokens,
3896 // });
3897
3898 // let created = SystemTime::now()
3899 // .duration_since(std::time::UNIX_EPOCH)
3900 // .map_err(|e| {
3901 // let err_msg = format!(
3902 // "Failed to get the current time. Reason: {e}"
3903 // );
3904
3905 // #[cfg(feature = "logging")]
3906 // error!(target: "stdout", "{}", &err_msg);
3907
3908 // LlamaCoreError::Operation(err_msg)
3909 // })?;
3910
3911 // let chat_completion_chunk = ChatCompletionChunk {
3912 // id,
3913 // object: "chat.completion.chunk".to_string(),
3914 // created: created.as_secs(),
3915 // model: graph.name().to_owned(),
3916 // system_fingerprint: "fp_44709d6fcb".to_string(),
3917 // choices: vec![],
3918 // usage,
3919 // };
3920
3921 // // serialize chat completion chunk
3922 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
3923 // .map_err(|e| {
3924 // let err_msg = format!(
3925 // "Failed to serialize chat completion chunk. Reason: {e}"
3926 // );
3927
3928 // #[cfg(feature = "logging")]
3929 // error!(target: "stdout", "{}", &err_msg);
3930
3931 // LlamaCoreError::Operation(err_msg)
3932 // })?;
3933
3934 // Ok(format!("data: {chunk_str}\n\n"))
3935 // }
3936 // ContextFullState::Done => {
3937 // *context_full_state = ContextFullState::EndOfSequence;
3938
3939 // Ok("data: [DONE]\n\n".to_string())
3940 // }
3941 // ContextFullState::EndOfSequence => {
3942 // Ok("[GGML] End of sequence".to_string())
3943 // }
3944 // }
3945 // }
3946 // Err(wasmedge_wasi_nn::Error::BackendError(
3947 // wasmedge_wasi_nn::BackendError::PromptTooLong,
3948 // )) => {
3949 // #[cfg(feature = "logging")]
3950 // debug!(target: "stdout", "Prompt too long");
3951
3952 // match prompt_too_long_state {
3953 // PromptTooLongState::Message => {
3954 // match include_usage {
3955 // true => *prompt_too_long_state = PromptTooLongState::Usage,
3956 // false => *prompt_too_long_state = PromptTooLongState::Done,
3957 // }
3958
3959 // let created = SystemTime::now()
3960 // .duration_since(std::time::UNIX_EPOCH)
3961 // .map_err(|e| {
3962 // let err_msg = format!(
3963 // "Failed to get the current time. Reason: {e}"
3964 // );
3965
3966 // #[cfg(feature = "logging")]
3967 // error!(target: "stdout", "{}", &err_msg);
3968
3969 // LlamaCoreError::Operation(err_msg)
3970 // })?;
3971
3972 // let chat_completion_chunk = ChatCompletionChunk {
3973 // id,
3974 // object: "chat.completion.chunk".to_string(),
3975 // created: created.as_secs(),
3976 // model: graph.name().to_owned(),
3977 // system_fingerprint: "fp_44709d6fcb".to_string(),
3978 // choices: vec![ChatCompletionChunkChoice {
3979 // index: 0,
3980 // delta: ChatCompletionChunkChoiceDelta {
3981 // role: ChatCompletionRole::Assistant,
3982 // content: None,
3983 // tool_calls: vec![],
3984 // },
3985 // logprobs: None,
3986 // finish_reason: Some(FinishReason::length),
3987 // }],
3988 // usage: None,
3989 // };
3990
3991 // // serialize chat completion chunk
3992 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
3993 // .map_err(|e| {
3994 // let err_msg = format!(
3995 // "Failed to serialize chat completion chunk. Reason: {e}"
3996 // );
3997
3998 // #[cfg(feature = "logging")]
3999 // error!(target: "stdout", "{}", &err_msg);
4000
4001 // LlamaCoreError::Operation(err_msg)
4002 // })?;
4003
4004 // Ok(format!("data: {chunk_str}\n\n"))
4005 // }
4006 // PromptTooLongState::Usage => {
4007 // *prompt_too_long_state = PromptTooLongState::Done;
4008
4009 // // retrieve the number of prompt and completion tokens
4010 // let token_info = get_token_info_by_graph(graph)?;
4011
4012 // let usage = Some(Usage {
4013 // prompt_tokens: token_info.prompt_tokens,
4014 // completion_tokens: token_info.completion_tokens,
4015 // total_tokens: token_info.prompt_tokens
4016 // + token_info.completion_tokens,
4017 // });
4018
4019 // let created = SystemTime::now()
4020 // .duration_since(std::time::UNIX_EPOCH)
4021 // .map_err(|e| {
4022 // let err_msg = format!(
4023 // "Failed to get the current time. Reason: {e}"
4024 // );
4025
4026 // #[cfg(feature = "logging")]
4027 // error!(target: "stdout", "{}", &err_msg);
4028
4029 // LlamaCoreError::Operation(err_msg)
4030 // })?;
4031
4032 // let chat_completion_chunk = ChatCompletionChunk {
4033 // id,
4034 // object: "chat.completion.chunk".to_string(),
4035 // created: created.as_secs(),
4036 // model: graph.name().to_owned(),
4037 // system_fingerprint: "fp_44709d6fcb".to_string(),
4038 // choices: vec![],
4039 // usage,
4040 // };
4041
4042 // // serialize chat completion chunk
4043 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4044 // .map_err(|e| {
4045 // let err_msg = format!(
4046 // "Failed to serialize chat completion chunk. Reason: {e}"
4047 // );
4048
4049 // #[cfg(feature = "logging")]
4050 // error!(target: "stdout", "{}", &err_msg);
4051
4052 // LlamaCoreError::Operation(err_msg)
4053 // })?;
4054
4055 // Ok(format!("data: {chunk_str}\n\n"))
4056 // }
4057 // PromptTooLongState::Done => {
4058 // *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4059
4060 // Ok("data: [DONE]\n\n".to_string())
4061 // }
4062 // PromptTooLongState::EndOfSequence => {
4063 // Ok("[GGML] End of sequence".to_string())
4064 // }
4065 // }
4066 // }
4067 // Err(e) => {
4068 // let err_msg =
4069 // format!("Failed to compute the chat completion. Reason: {e}");
4070
4071 // #[cfg(feature = "logging")]
4072 // error!(target: "stdout", "{}", &err_msg);
4073
4074 // Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4075 // err_msg,
4076 // )))
4077 // }
4078 // }
4079 // }
4080 // false => {
4081 // match chat_graphs.iter_mut().next() {
4082 // Some((_, graph)) => {
4083 // // compute
4084 // match graph.compute_single() {
4085 // Ok(_) => {
4086 // #[cfg(feature = "logging")]
4087 // debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4088
4089 // match stream_state {
4090 // StreamState::Usage | StreamState::NoUsage => {
4091 // // Retrieve the output
4092 // let output_buffer =
4093 // get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4094
4095 // #[cfg(feature = "logging")]
4096 // info!(target: "stdout", "retrieved the output buffer");
4097
4098 // // decode the output buffer to a utf8 string
4099 // let output = match String::from_utf8(
4100 // output_buffer.clone(),
4101 // ) {
4102 // Ok(token) => token,
4103 // Err(_) => {
4104 // let mutex = CACHED_UTF8_ENCODINGS
4105 // .get_or_init(|| Mutex::new(Vec::new()));
4106 // let mut cached_encodings = mutex.lock().map_err(|e| {
4107 // let err_msg = format!(
4108 // "Fail to acquire the lock of `UTF8_ENCODINGS`. Reason: {e}"
4109 // );
4110
4111 // #[cfg(feature = "logging")]
4112 // error!(target: "stdout", "{}", &err_msg);
4113
4114 // LlamaCoreError::Operation(err_msg)
4115 // })?;
4116
4117 // // cache the bytes for future decoding
4118 // cached_encodings
4119 // .extend_from_slice(&output_buffer[..]);
4120
4121 // match String::from_utf8(
4122 // cached_encodings.to_vec(),
4123 // ) {
4124 // Ok(token) => {
4125 // // clear encodings
4126 // cached_encodings.clear();
4127
4128 // token
4129 // }
4130 // Err(e) => {
4131 // // TODO This is a temp check. In case, infinite cached encodings happen.
4132 // if cached_encodings.len() > 4 {
4133 // let err_msg = format!("Fail to convert a vector of bytes to string. The length of the utf8 bytes exceeds 4. {e}");
4134
4135 // #[cfg(feature = "logging")]
4136 // error!(target: "stdout", "{}", &err_msg);
4137
4138 // #[cfg(feature = "logging")]
4139 // error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4140
4141 // // let token =
4142 // // String::from_utf8_lossy(
4143 // // &cached_encodings,
4144 // // )
4145 // // .to_string();
4146
4147 // // clear CACHED_UTF8_ENCODINGS
4148 // cached_encodings.clear();
4149
4150 // String::from("")
4151 // } else {
4152 // let warn_msg = format!("Fail to convert a vector of bytes to string. {e}");
4153
4154 // #[cfg(feature = "logging")]
4155 // warn!(target: "stdout", "{}", &warn_msg);
4156
4157 // String::from("")
4158 // }
4159 // }
4160 // }
4161 // }
4162 // };
4163
4164 // #[cfg(feature = "logging")]
4165 // info!(target: "stdout", "decoded the output buffer");
4166
4167 // let created = SystemTime::now()
4168 // .duration_since(std::time::UNIX_EPOCH)
4169 // .map_err(|e| {
4170 // let err_msg = format!(
4171 // "Failed to get the current time. Reason: {e}"
4172 // );
4173
4174 // #[cfg(feature = "logging")]
4175 // error!(target: "stdout", "{}", &err_msg);
4176
4177 // LlamaCoreError::Operation(err_msg)
4178 // })?;
4179
4180 // let chat_completion_chunk = ChatCompletionChunk {
4181 // id,
4182 // object: "chat.completion.chunk".to_string(),
4183 // created: created.as_secs(),
4184 // model: graph.name().to_owned(),
4185 // system_fingerprint: "fp_44709d6fcb".to_string(),
4186 // choices: vec![ChatCompletionChunkChoice {
4187 // index: 0,
4188 // delta: ChatCompletionChunkChoiceDelta {
4189 // role: ChatCompletionRole::Assistant,
4190 // content: Some(output),
4191 // tool_calls: vec![],
4192 // },
4193 // logprobs: None,
4194 // finish_reason: None,
4195 // }],
4196 // usage: None,
4197 // };
4198
4199 // #[cfg(feature = "logging")]
4200 // info!(target: "stdout", "created chat completion chunk");
4201
4202 // // serialize chat completion chunk
4203 // let chunk_str =
4204 // serde_json::to_string(&chat_completion_chunk)
4205 // .map_err(|e| {
4206 // let err_msg = format!(
4207 // "Failed to serialize chat completion chunk. Reason: {e}"
4208 // );
4209
4210 // #[cfg(feature = "logging")]
4211 // error!(target: "stdout", "{}", &err_msg);
4212
4213 // LlamaCoreError::Operation(err_msg)
4214 // })?;
4215
4216 // Ok(format!("data: {chunk_str}\n\n"))
4217 // }
4218 // StreamState::Done => {
4219 // *stream_state = StreamState::EndOfSequence;
4220
4221 // Ok("data: [DONE]\n\n".to_string())
4222 // }
4223 // StreamState::EndOfSequence => {
4224 // Ok("[GGML] End of sequence".to_string())
4225 // }
4226 // }
4227 // }
4228 // Err(wasmedge_wasi_nn::Error::BackendError(
4229 // wasmedge_wasi_nn::BackendError::EndOfSequence,
4230 // )) => {
4231 // #[cfg(feature = "logging")]
4232 // debug!(target: "stdout", "End of sequence");
4233
4234 // match stream_state {
4235 // StreamState::Usage => {
4236 // *stream_state = StreamState::Done;
4237
4238 // // retrieve the number of prompt and completion tokens
4239 // let token_info = get_token_info_by_graph(graph)?;
4240
4241 // let usage = Some(Usage {
4242 // prompt_tokens: token_info.prompt_tokens,
4243 // completion_tokens: token_info.completion_tokens,
4244 // total_tokens: token_info.prompt_tokens
4245 // + token_info.completion_tokens,
4246 // });
4247
4248 // #[cfg(feature = "logging")]
4249 // info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4250
4251 // let created = SystemTime::now()
4252 // .duration_since(std::time::UNIX_EPOCH)
4253 // .map_err(|e| {
4254 // let err_msg = format!(
4255 // "Failed to get the current time. Reason: {e}"
4256 // );
4257
4258 // #[cfg(feature = "logging")]
4259 // error!(target: "stdout", "{}", &err_msg);
4260
4261 // LlamaCoreError::Operation(err_msg)
4262 // })?;
4263
4264 // let chat_completion_chunk = ChatCompletionChunk {
4265 // id,
4266 // object: "chat.completion.chunk".to_string(),
4267 // created: created.as_secs(),
4268 // model: graph.name().to_owned(),
4269 // system_fingerprint: "fp_44709d6fcb".to_string(),
4270 // choices: vec![],
4271 // usage,
4272 // };
4273
4274 // // serialize chat completion chunk
4275 // let chunk_str =
4276 // serde_json::to_string(&chat_completion_chunk)
4277 // .map_err(|e| {
4278 // let err_msg = format!(
4279 // "Failed to serialize chat completion chunk. Reason: {e}"
4280 // );
4281
4282 // #[cfg(feature = "logging")]
4283 // error!(target: "stdout", "{}", &err_msg);
4284
4285 // LlamaCoreError::Operation(err_msg)
4286 // })?;
4287
4288 // Ok(format!("data: {chunk_str}\n\n"))
4289 // }
4290 // StreamState::Done | StreamState::NoUsage => {
4291 // *stream_state = StreamState::EndOfSequence;
4292
4293 // Ok("data: [DONE]\n\n".to_string())
4294 // }
4295 // StreamState::EndOfSequence => {
4296 // Ok("[GGML] End of sequence".to_string())
4297 // }
4298 // }
4299 // }
4300 // Err(wasmedge_wasi_nn::Error::BackendError(
4301 // wasmedge_wasi_nn::BackendError::ContextFull,
4302 // )) => {
4303 // #[cfg(feature = "logging")]
4304 // debug!(target: "stdout", "Context full");
4305
4306 // match context_full_state {
4307 // ContextFullState::Message => {
4308 // match include_usage {
4309 // true => {
4310 // *context_full_state = ContextFullState::Usage
4311 // }
4312 // false => {
4313 // *context_full_state = ContextFullState::Done
4314 // }
4315 // }
4316
4317 // let created = SystemTime::now()
4318 // .duration_since(std::time::UNIX_EPOCH)
4319 // .map_err(|e| {
4320 // let err_msg = format!(
4321 // "Failed to get the current time. Reason: {e}"
4322 // );
4323
4324 // #[cfg(feature = "logging")]
4325 // error!(target: "stdout", "{}", &err_msg);
4326
4327 // LlamaCoreError::Operation(err_msg)
4328 // })?;
4329
4330 // let chat_completion_chunk = ChatCompletionChunk {
4331 // id,
4332 // object: "chat.completion.chunk".to_string(),
4333 // created: created.as_secs(),
4334 // model: graph.name().to_owned(),
4335 // system_fingerprint: "fp_44709d6fcb".to_string(),
4336 // choices: vec![ChatCompletionChunkChoice {
4337 // index: 0,
4338 // delta: ChatCompletionChunkChoiceDelta {
4339 // role: ChatCompletionRole::Assistant,
4340 // content: Some(
4341 // "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4342 // .to_string(),
4343 // ),
4344 // tool_calls: vec![],
4345 // },
4346 // logprobs: None,
4347 // finish_reason: Some(FinishReason::length),
4348 // }],
4349 // usage: None,
4350 // };
4351
4352 // // serialize chat completion chunk
4353 // let chunk_str =
4354 // serde_json::to_string(&chat_completion_chunk)
4355 // .map_err(|e| {
4356 // let err_msg = format!(
4357 // "Failed to serialize chat completion chunk. Reason: {e}"
4358 // );
4359
4360 // #[cfg(feature = "logging")]
4361 // error!(target: "stdout", "{}", &err_msg);
4362
4363 // LlamaCoreError::Operation(err_msg)
4364 // })?;
4365
4366 // Ok(format!("data: {chunk_str}\n\n"))
4367 // }
4368 // ContextFullState::Usage => {
4369 // *context_full_state = ContextFullState::Done;
4370
4371 // // retrieve the number of prompt and completion tokens
4372 // let token_info = get_token_info_by_graph(graph)?;
4373
4374 // let usage = Some(Usage {
4375 // prompt_tokens: token_info.prompt_tokens,
4376 // completion_tokens: token_info.completion_tokens,
4377 // total_tokens: token_info.prompt_tokens
4378 // + token_info.completion_tokens,
4379 // });
4380
4381 // let created = SystemTime::now()
4382 // .duration_since(std::time::UNIX_EPOCH)
4383 // .map_err(|e| {
4384 // let err_msg = format!(
4385 // "Failed to get the current time. Reason: {e}"
4386 // );
4387
4388 // #[cfg(feature = "logging")]
4389 // error!(target: "stdout", "{}", &err_msg);
4390
4391 // LlamaCoreError::Operation(err_msg)
4392 // })?;
4393
4394 // let chat_completion_chunk = ChatCompletionChunk {
4395 // id,
4396 // object: "chat.completion.chunk".to_string(),
4397 // created: created.as_secs(),
4398 // model: graph.name().to_owned(),
4399 // system_fingerprint: "fp_44709d6fcb".to_string(),
4400 // choices: vec![],
4401 // usage,
4402 // };
4403
4404 // // serialize chat completion chunk
4405 // let chunk_str =
4406 // serde_json::to_string(&chat_completion_chunk)
4407 // .map_err(|e| {
4408 // let err_msg = format!(
4409 // "Failed to serialize chat completion chunk. Reason: {e}"
4410 // );
4411
4412 // #[cfg(feature = "logging")]
4413 // error!(target: "stdout", "{}", &err_msg);
4414
4415 // LlamaCoreError::Operation(err_msg)
4416 // })?;
4417
4418 // Ok(format!("data: {chunk_str}\n\n"))
4419 // }
4420 // ContextFullState::Done => {
4421 // *context_full_state = ContextFullState::EndOfSequence;
4422
4423 // Ok("data: [DONE]\n\n".to_string())
4424 // }
4425 // ContextFullState::EndOfSequence => {
4426 // Ok("[GGML] End of sequence".to_string())
4427 // }
4428 // }
4429 // }
4430 // Err(wasmedge_wasi_nn::Error::BackendError(
4431 // wasmedge_wasi_nn::BackendError::PromptTooLong,
4432 // )) => {
4433 // #[cfg(feature = "logging")]
4434 // debug!(target: "stdout", "Prompt too long");
4435
4436 // match prompt_too_long_state {
4437 // PromptTooLongState::Message => {
4438 // match include_usage {
4439 // true => {
4440 // *prompt_too_long_state =
4441 // PromptTooLongState::Usage
4442 // }
4443 // false => {
4444 // *prompt_too_long_state =
4445 // PromptTooLongState::Done
4446 // }
4447 // }
4448
4449 // let created = SystemTime::now()
4450 // .duration_since(std::time::UNIX_EPOCH)
4451 // .map_err(|e| {
4452 // let err_msg = format!(
4453 // "Failed to get the current time. Reason: {e}"
4454 // );
4455
4456 // #[cfg(feature = "logging")]
4457 // error!(target: "stdout", "{}", &err_msg);
4458
4459 // LlamaCoreError::Operation(err_msg)
4460 // })?;
4461
4462 // let chat_completion_chunk = ChatCompletionChunk {
4463 // id,
4464 // object: "chat.completion.chunk".to_string(),
4465 // created: created.as_secs(),
4466 // model: graph.name().to_owned(),
4467 // system_fingerprint: "fp_44709d6fcb".to_string(),
4468 // choices: vec![ChatCompletionChunkChoice {
4469 // index: 0,
4470 // delta: ChatCompletionChunkChoiceDelta {
4471 // role: ChatCompletionRole::Assistant,
4472 // content: None,
4473 // tool_calls: vec![],
4474 // },
4475 // logprobs: None,
4476 // finish_reason: Some(FinishReason::length),
4477 // }],
4478 // usage: None,
4479 // };
4480
4481 // // serialize chat completion chunk
4482 // let chunk_str =
4483 // serde_json::to_string(&chat_completion_chunk)
4484 // .map_err(|e| {
4485 // let err_msg = format!(
4486 // "Failed to serialize chat completion chunk. Reason: {e}"
4487 // );
4488
4489 // #[cfg(feature = "logging")]
4490 // error!(target: "stdout", "{}", &err_msg);
4491
4492 // LlamaCoreError::Operation(err_msg)
4493 // })?;
4494
4495 // Ok(format!("data: {chunk_str}\n\n"))
4496 // }
4497 // PromptTooLongState::Usage => {
4498 // *prompt_too_long_state = PromptTooLongState::Done;
4499
4500 // // retrieve the number of prompt and completion tokens
4501 // let token_info = get_token_info_by_graph(graph)?;
4502
4503 // let usage = Some(Usage {
4504 // prompt_tokens: token_info.prompt_tokens,
4505 // completion_tokens: token_info.completion_tokens,
4506 // total_tokens: token_info.prompt_tokens
4507 // + token_info.completion_tokens,
4508 // });
4509
4510 // let created = SystemTime::now()
4511 // .duration_since(std::time::UNIX_EPOCH)
4512 // .map_err(|e| {
4513 // let err_msg = format!(
4514 // "Failed to get the current time. Reason: {e}"
4515 // );
4516
4517 // #[cfg(feature = "logging")]
4518 // error!(target: "stdout", "{}", &err_msg);
4519
4520 // LlamaCoreError::Operation(err_msg)
4521 // })?;
4522
4523 // let chat_completion_chunk = ChatCompletionChunk {
4524 // id,
4525 // object: "chat.completion.chunk".to_string(),
4526 // created: created.as_secs(),
4527 // model: graph.name().to_owned(),
4528 // system_fingerprint: "fp_44709d6fcb".to_string(),
4529 // choices: vec![],
4530 // usage,
4531 // };
4532
4533 // // serialize chat completion chunk
4534 // let chunk_str =
4535 // serde_json::to_string(&chat_completion_chunk)
4536 // .map_err(|e| {
4537 // let err_msg = format!(
4538 // "Failed to serialize chat completion chunk. Reason: {e}"
4539 // );
4540
4541 // #[cfg(feature = "logging")]
4542 // error!(target: "stdout", "{}", &err_msg);
4543
4544 // LlamaCoreError::Operation(err_msg)
4545 // })?;
4546
4547 // Ok(format!("data: {chunk_str}\n\n"))
4548 // }
4549 // PromptTooLongState::Done => {
4550 // *prompt_too_long_state =
4551 // PromptTooLongState::EndOfSequence;
4552
4553 // Ok("data: [DONE]\n\n".to_string())
4554 // }
4555 // PromptTooLongState::EndOfSequence => {
4556 // Ok("[GGML] End of sequence".to_string())
4557 // }
4558 // }
4559 // }
4560 // Err(e) => {
4561 // let err_msg = format!(
4562 // "Failed to compute the chat completion. Reason: {e}"
4563 // );
4564
4565 // #[cfg(feature = "logging")]
4566 // error!(target: "stdout", "{}", &err_msg);
4567
4568 // Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4569 // err_msg,
4570 // )))
4571 // }
4572 // }
4573 // }
4574 // None => {
4575 // let err_msg = "There is no model available in the chat graphs.";
4576
4577 // #[cfg(feature = "logging")]
4578 // error!(target: "stdout", "{}", &err_msg);
4579
4580 // Err(LlamaCoreError::Operation(err_msg.into()))
4581 // }
4582 // }
4583 // }
4584 // }
4585 // }
4586 // None => {
4587 // match chat_graphs.iter_mut().next() {
4588 // Some((_, graph)) => {
4589 // // compute
4590 // match graph.compute_single() {
4591 // Ok(_) => {
4592 // #[cfg(feature = "logging")]
4593 // debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4594
4595 // match stream_state {
4596 // StreamState::Usage | StreamState::NoUsage => {
4597 // // Retrieve the output
4598 // let output_buffer =
4599 // get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4600
4601 // #[cfg(feature = "logging")]
4602 // info!(target: "stdout", "retrieved the output buffer");
4603
4604 // // decode the output buffer to a utf8 string
4605 // let output = match String::from_utf8(output_buffer.clone()) {
4606 // Ok(token) => token,
4607 // Err(_) => {
4608 // let mutex = CACHED_UTF8_ENCODINGS
4609 // .get_or_init(|| Mutex::new(Vec::new()));
4610 // let mut cached_encodings = mutex.lock().map_err(|e| {
4611 // let err_msg = format!(
4612 // "Fail to acquire the lock of `UTF8_ENCODINGS`. Reason: {e}"
4613 // );
4614
4615 // #[cfg(feature = "logging")]
4616 // error!(target: "stdout", "{}", &err_msg);
4617
4618 // LlamaCoreError::Operation(err_msg)
4619 // })?;
4620
4621 // cached_encodings.extend_from_slice(&output_buffer[..]);
4622
4623 // match String::from_utf8(cached_encodings.to_vec()) {
4624 // Ok(token) => {
4625 // // clear encodings
4626 // cached_encodings.clear();
4627
4628 // token
4629 // }
4630 // Err(e) => {
4631 // // TODO This is a temp check. In case, infinite cached encodings happen.
4632 // if cached_encodings.len() > 4 {
4633 // let err_msg = format!("Fail to convert a vector of bytes to string. The length of the utf8 bytes exceeds 4. {e}");
4634
4635 // #[cfg(feature = "logging")]
4636 // error!(target: "stdout", "{}", &err_msg);
4637
4638 // #[cfg(feature = "logging")]
4639 // error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4640
4641 // // let token = String::from_utf8_lossy(
4642 // // &cached_encodings,
4643 // // )
4644 // // .to_string();
4645
4646 // // clear CACHED_UTF8_ENCODINGS
4647 // cached_encodings.clear();
4648
4649 // String::from("")
4650 // } else {
4651 // let warn_msg = format!("Fail to convert a vector of bytes to string. {e}");
4652
4653 // #[cfg(feature = "logging")]
4654 // warn!(target: "stdout", "{}", &warn_msg);
4655
4656 // String::from("")
4657 // }
4658 // }
4659 // }
4660 // }
4661 // };
4662
4663 // #[cfg(feature = "logging")]
4664 // info!(target: "stdout", "decoded the output buffer");
4665
4666 // let created = SystemTime::now()
4667 // .duration_since(std::time::UNIX_EPOCH)
4668 // .map_err(|e| {
4669 // let err_msg = format!(
4670 // "Failed to get the current time. Reason: {e}"
4671 // );
4672
4673 // #[cfg(feature = "logging")]
4674 // error!(target: "stdout", "{}", &err_msg);
4675
4676 // LlamaCoreError::Operation(err_msg)
4677 // })?;
4678
4679 // let chat_completion_chunk = ChatCompletionChunk {
4680 // id,
4681 // object: "chat.completion.chunk".to_string(),
4682 // created: created.as_secs(),
4683 // model: graph.name().to_owned(),
4684 // system_fingerprint: "fp_44709d6fcb".to_string(),
4685 // choices: vec![ChatCompletionChunkChoice {
4686 // index: 0,
4687 // delta: ChatCompletionChunkChoiceDelta {
4688 // role: ChatCompletionRole::Assistant,
4689 // content: Some(output),
4690 // tool_calls: vec![],
4691 // },
4692 // logprobs: None,
4693 // finish_reason: None,
4694 // }],
4695 // usage: None,
4696 // };
4697
4698 // #[cfg(feature = "logging")]
4699 // info!(target: "stdout", "created chat completion chunk");
4700
4701 // // serialize chat completion chunk
4702 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4703 // .map_err(|e| {
4704 // let err_msg = format!(
4705 // "Failed to serialize chat completion chunk. Reason: {e}"
4706 // );
4707
4708 // #[cfg(feature = "logging")]
4709 // error!(target: "stdout", "{}", &err_msg);
4710
4711 // LlamaCoreError::Operation(err_msg)
4712 // })?;
4713
4714 // Ok(format!("data: {chunk_str}\n\n"))
4715 // }
4716 // StreamState::Done => {
4717 // *stream_state = StreamState::EndOfSequence;
4718
4719 // Ok("data: [DONE]\n\n".to_string())
4720 // }
4721 // StreamState::EndOfSequence => {
4722 // Ok("[GGML] End of sequence".to_string())
4723 // }
4724 // }
4725 // }
4726 // Err(wasmedge_wasi_nn::Error::BackendError(
4727 // wasmedge_wasi_nn::BackendError::EndOfSequence,
4728 // )) => {
4729 // #[cfg(feature = "logging")]
4730 // debug!(target: "stdout", "End of sequence");
4731
4732 // match stream_state {
4733 // StreamState::Usage => {
4734 // *stream_state = StreamState::Done;
4735
4736 // // retrieve the number of prompt and completion tokens
4737 // let token_info = get_token_info_by_graph(graph)?;
4738
4739 // let usage = Some(Usage {
4740 // prompt_tokens: token_info.prompt_tokens,
4741 // completion_tokens: token_info.completion_tokens,
4742 // total_tokens: token_info.prompt_tokens
4743 // + token_info.completion_tokens,
4744 // });
4745
4746 // #[cfg(feature = "logging")]
4747 // info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4748
4749 // let created = SystemTime::now()
4750 // .duration_since(std::time::UNIX_EPOCH)
4751 // .map_err(|e| {
4752 // let err_msg = format!(
4753 // "Failed to get the current time. Reason: {e}"
4754 // );
4755
4756 // #[cfg(feature = "logging")]
4757 // error!(target: "stdout", "{}", &err_msg);
4758
4759 // LlamaCoreError::Operation(err_msg)
4760 // })?;
4761
4762 // let chat_completion_chunk = ChatCompletionChunk {
4763 // id,
4764 // object: "chat.completion.chunk".to_string(),
4765 // created: created.as_secs(),
4766 // model: graph.name().to_owned(),
4767 // system_fingerprint: "fp_44709d6fcb".to_string(),
4768 // choices: vec![],
4769 // usage,
4770 // };
4771
4772 // // serialize chat completion chunk
4773 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4774 // .map_err(|e| {
4775 // let err_msg = format!(
4776 // "Failed to serialize chat completion chunk. Reason: {e}"
4777 // );
4778
4779 // #[cfg(feature = "logging")]
4780 // error!(target: "stdout", "{}", &err_msg);
4781
4782 // LlamaCoreError::Operation(err_msg)
4783 // })?;
4784
4785 // Ok(format!("data: {chunk_str}\n\n"))
4786 // }
4787 // StreamState::Done | StreamState::NoUsage => {
4788 // *stream_state = StreamState::EndOfSequence;
4789
4790 // Ok("data: [DONE]\n\n".to_string())
4791 // }
4792 // StreamState::EndOfSequence => {
4793 // Ok("[GGML] End of sequence".to_string())
4794 // }
4795 // }
4796 // }
4797 // Err(wasmedge_wasi_nn::Error::BackendError(
4798 // wasmedge_wasi_nn::BackendError::ContextFull,
4799 // )) => {
4800 // #[cfg(feature = "logging")]
4801 // debug!(target: "stdout", "Context full");
4802
4803 // match context_full_state {
4804 // ContextFullState::Message => {
4805 // match include_usage {
4806 // true => *context_full_state = ContextFullState::Usage,
4807 // false => *context_full_state = ContextFullState::Done,
4808 // }
4809
4810 // let created = SystemTime::now()
4811 // .duration_since(std::time::UNIX_EPOCH)
4812 // .map_err(|e| {
4813 // let err_msg = format!(
4814 // "Failed to get the current time. Reason: {e}"
4815 // );
4816
4817 // #[cfg(feature = "logging")]
4818 // error!(target: "stdout", "{}", &err_msg);
4819
4820 // LlamaCoreError::Operation(err_msg)
4821 // })?;
4822
4823 // let chat_completion_chunk = ChatCompletionChunk {
4824 // id,
4825 // object: "chat.completion.chunk".to_string(),
4826 // created: created.as_secs(),
4827 // model: graph.name().to_owned(),
4828 // system_fingerprint: "fp_44709d6fcb".to_string(),
4829 // choices: vec![ChatCompletionChunkChoice {
4830 // index: 0,
4831 // delta: ChatCompletionChunkChoiceDelta {
4832 // role: ChatCompletionRole::Assistant,
4833 // content: Some(
4834 // "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4835 // ),
4836 // tool_calls: vec![],
4837 // },
4838 // logprobs: None,
4839 // finish_reason: Some(FinishReason::length),
4840 // }],
4841 // usage: None,
4842 // };
4843
4844 // // serialize chat completion chunk
4845 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4846 // .map_err(|e| {
4847 // let err_msg = format!(
4848 // "Failed to serialize chat completion chunk. Reason: {e}"
4849 // );
4850
4851 // #[cfg(feature = "logging")]
4852 // error!(target: "stdout", "{}", &err_msg);
4853
4854 // LlamaCoreError::Operation(err_msg)
4855 // })?;
4856
4857 // Ok(format!("data: {chunk_str}\n\n"))
4858 // }
4859 // ContextFullState::Usage => {
4860 // *context_full_state = ContextFullState::Done;
4861
4862 // // retrieve the number of prompt and completion tokens
4863 // let token_info = get_token_info_by_graph(graph)?;
4864
4865 // let usage = Some(Usage {
4866 // prompt_tokens: token_info.prompt_tokens,
4867 // completion_tokens: token_info.completion_tokens,
4868 // total_tokens: token_info.prompt_tokens
4869 // + token_info.completion_tokens,
4870 // });
4871
4872 // let created = SystemTime::now()
4873 // .duration_since(std::time::UNIX_EPOCH)
4874 // .map_err(|e| {
4875 // let err_msg = format!(
4876 // "Failed to get the current time. Reason: {e}"
4877 // );
4878
4879 // #[cfg(feature = "logging")]
4880 // error!(target: "stdout", "{}", &err_msg);
4881
4882 // LlamaCoreError::Operation(err_msg)
4883 // })?;
4884
4885 // let chat_completion_chunk = ChatCompletionChunk {
4886 // id,
4887 // object: "chat.completion.chunk".to_string(),
4888 // created: created.as_secs(),
4889 // model: graph.name().to_owned(),
4890 // system_fingerprint: "fp_44709d6fcb".to_string(),
4891 // choices: vec![],
4892 // usage,
4893 // };
4894
4895 // // serialize chat completion chunk
4896 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4897 // .map_err(|e| {
4898 // let err_msg = format!(
4899 // "Failed to serialize chat completion chunk. Reason: {e}"
4900 // );
4901
4902 // #[cfg(feature = "logging")]
4903 // error!(target: "stdout", "{}", &err_msg);
4904
4905 // LlamaCoreError::Operation(err_msg)
4906 // })?;
4907
4908 // Ok(format!("data: {chunk_str}\n\n"))
4909 // }
4910 // ContextFullState::Done => {
4911 // *context_full_state = ContextFullState::EndOfSequence;
4912
4913 // Ok("data: [DONE]\n\n".to_string())
4914 // }
4915 // ContextFullState::EndOfSequence => {
4916 // Ok("[GGML] End of sequence".to_string())
4917 // }
4918 // }
4919 // }
4920 // Err(wasmedge_wasi_nn::Error::BackendError(
4921 // wasmedge_wasi_nn::BackendError::PromptTooLong,
4922 // )) => {
4923 // #[cfg(feature = "logging")]
4924 // debug!(target: "stdout", "Prompt too long");
4925
4926 // match prompt_too_long_state {
4927 // PromptTooLongState::Message => {
4928 // match include_usage {
4929 // true => *prompt_too_long_state = PromptTooLongState::Usage,
4930 // false => *prompt_too_long_state = PromptTooLongState::Done,
4931 // }
4932
4933 // let created = SystemTime::now()
4934 // .duration_since(std::time::UNIX_EPOCH)
4935 // .map_err(|e| {
4936 // let err_msg = format!(
4937 // "Failed to get the current time. Reason: {e}"
4938 // );
4939
4940 // #[cfg(feature = "logging")]
4941 // error!(target: "stdout", "{}", &err_msg);
4942
4943 // LlamaCoreError::Operation(err_msg)
4944 // })?;
4945
4946 // let chat_completion_chunk = ChatCompletionChunk {
4947 // id,
4948 // object: "chat.completion.chunk".to_string(),
4949 // created: created.as_secs(),
4950 // model: graph.name().to_owned(),
4951 // system_fingerprint: "fp_44709d6fcb".to_string(),
4952 // choices: vec![ChatCompletionChunkChoice {
4953 // index: 0,
4954 // delta: ChatCompletionChunkChoiceDelta {
4955 // role: ChatCompletionRole::Assistant,
4956 // content: None,
4957 // tool_calls: vec![],
4958 // },
4959 // logprobs: None,
4960 // finish_reason: Some(FinishReason::length),
4961 // }],
4962 // usage: None,
4963 // };
4964
4965 // // serialize chat completion chunk
4966 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4967 // .map_err(|e| {
4968 // let err_msg = format!(
4969 // "Failed to serialize chat completion chunk. Reason: {e}"
4970 // );
4971
4972 // #[cfg(feature = "logging")]
4973 // error!(target: "stdout", "{}", &err_msg);
4974
4975 // LlamaCoreError::Operation(err_msg)
4976 // })?;
4977
4978 // Ok(format!("data: {chunk_str}\n\n"))
4979 // }
4980 // PromptTooLongState::Usage => {
4981 // *prompt_too_long_state = PromptTooLongState::Done;
4982
4983 // // retrieve the number of prompt and completion tokens
4984 // let token_info = get_token_info_by_graph(graph)?;
4985
4986 // let usage = Some(Usage {
4987 // prompt_tokens: token_info.prompt_tokens,
4988 // completion_tokens: token_info.completion_tokens,
4989 // total_tokens: token_info.prompt_tokens
4990 // + token_info.completion_tokens,
4991 // });
4992
4993 // let created = SystemTime::now()
4994 // .duration_since(std::time::UNIX_EPOCH)
4995 // .map_err(|e| {
4996 // let err_msg = format!(
4997 // "Failed to get the current time. Reason: {e}"
4998 // );
4999
5000 // #[cfg(feature = "logging")]
5001 // error!(target: "stdout", "{}", &err_msg);
5002
5003 // LlamaCoreError::Operation(err_msg)
5004 // })?;
5005
5006 // let chat_completion_chunk = ChatCompletionChunk {
5007 // id,
5008 // object: "chat.completion.chunk".to_string(),
5009 // created: created.as_secs(),
5010 // model: graph.name().to_owned(),
5011 // system_fingerprint: "fp_44709d6fcb".to_string(),
5012 // choices: vec![],
5013 // usage,
5014 // };
5015
5016 // // serialize chat completion chunk
5017 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
5018 // .map_err(|e| {
5019 // let err_msg = format!(
5020 // "Failed to serialize chat completion chunk. Reason: {e}"
5021 // );
5022
5023 // #[cfg(feature = "logging")]
5024 // error!(target: "stdout", "{}", &err_msg);
5025
5026 // LlamaCoreError::Operation(err_msg)
5027 // })?;
5028
5029 // Ok(format!("data: {chunk_str}\n\n"))
5030 // }
5031 // PromptTooLongState::Done => {
5032 // *prompt_too_long_state = PromptTooLongState::EndOfSequence;
5033
5034 // Ok("data: [DONE]\n\n".to_string())
5035 // }
5036 // PromptTooLongState::EndOfSequence => {
5037 // Ok("[GGML] End of sequence".to_string())
5038 // }
5039 // }
5040 // }
5041 // Err(e) => {
5042 // let err_msg =
5043 // format!("Failed to compute the chat completion. Reason: {e}");
5044
5045 // #[cfg(feature = "logging")]
5046 // error!(target: "stdout", "{}", &err_msg);
5047
5048 // Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
5049 // err_msg,
5050 // )))
5051 // }
5052 // }
5053 // }
5054 // None => {
5055 // let err_msg = "There is no model available in the chat graphs.";
5056
5057 // #[cfg(feature = "logging")]
5058 // error!(target: "stdout", "{}", &err_msg);
5059
5060 // Err(LlamaCoreError::Operation(err_msg.into()))
5061 // }
5062 // }
5063 // }
5064 // };
5065
5066 // #[cfg(feature = "logging")]
5067 // info!(target: "stdout", "Return the chat stream chunk!");
5068
5069 // res
5070 }
5071
5072 todo!("stream_chat_completion is not implemented yet");
5073}
5074
5075// #[allow(dead_code)]
5076// #[derive(Debug)]
5077// struct ParseResult {
5078// raw: String,
5079// content: Option<String>,
5080// tool_calls: Vec<ToolCall>,
5081// }
5082
5083/// Convert Input to a vector of ChatCompletionRequestMessage
5084/// Only handles InputItem::InputMessage, other variants are skipped with warnings
5085fn to_chat_messages(input: &Input) -> Result<Vec<ChatCompletionRequestMessage>, LlamaCoreError> {
5086 match input {
5087 Input::Text(text) => {
5088 let content = ChatCompletionUserMessageContent::Text(text.clone());
5089 let user_message = ChatCompletionRequestMessage::new_user_message(content, None);
5090 // Simple text converts to a user message
5091 Ok(vec![user_message])
5092 }
5093 Input::InputItemList(items) => {
5094 let mut messages = Vec::new();
5095 for item in items {
5096 match item {
5097 InputItem::InputMessage { content, role, .. } => {
5098 let message = input_message_to_chat_message(content, role)?;
5099 messages.push(message);
5100 }
5101 _ => {
5102 #[cfg(feature = "logging")]
5103 warn!(target: "stdout", "Skipping unsupported InputItem variant");
5104 }
5105 }
5106 }
5107 Ok(messages)
5108 }
5109 }
5110}
5111
5112/// Helper function to convert InputMessage to ChatCompletionRequestMessage
5113fn input_message_to_chat_message(
5114 content: &InputMessageContent,
5115 role: &str,
5116) -> Result<ChatCompletionRequestMessage, LlamaCoreError> {
5117 match role {
5118 "user" => {
5119 let content = input_message_content_to_chat_message_content(content)?;
5120 let content = ChatCompletionUserMessageContent::Text(content);
5121 Ok(ChatCompletionRequestMessage::new_user_message(
5122 content, None,
5123 ))
5124 }
5125 "assistant" => {
5126 let content = input_message_content_to_chat_message_content(content)?;
5127 Ok(ChatCompletionRequestMessage::new_assistant_message(
5128 Some(content),
5129 None,
5130 None,
5131 ))
5132 }
5133 "system" => {
5134 let content = input_message_content_to_chat_message_content(content)?;
5135 Ok(ChatCompletionRequestMessage::new_system_message(
5136 content, None,
5137 ))
5138 }
5139 "developer" => {
5140 let content = input_message_content_to_chat_message_content(content)?;
5141 Ok(ChatCompletionRequestMessage::new_developer_message(
5142 content, None,
5143 ))
5144 }
5145 _ => {
5146 let error_msg = format!("Unsupported role: {}", role);
5147
5148 #[cfg(feature = "logging")]
5149 error!(target: "stdout", "{}", &error_msg);
5150
5151 Err(LlamaCoreError::Operation(error_msg))
5152 }
5153 }
5154}
5155
5156/// Convert InputMessageContent to ChatCompletionUserMessageContent
5157fn input_message_content_to_chat_message_content(
5158 content: &InputMessageContent,
5159) -> Result<String, LlamaCoreError> {
5160 match content {
5161 InputMessageContent::Text(text) => Ok(text.clone()),
5162 InputMessageContent::InputItemContentList(_items) => {
5163 let error_msg = "Not support InputMessageContent::InputItemContentList";
5164
5165 #[cfg(feature = "logging")]
5166 error!(target: "stdout", "{}", error_msg);
5167
5168 Err(LlamaCoreError::Operation(error_msg.into()))
5169 }
5170 }
5171}