llama_core/chat/responses.rs
//! Define APIs for creating model responses.
#[allow(unused_imports)]
use crate::{
    error,
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{
        gen_chat_id, gen_response_id, get_output_buffer, get_output_buffer_single,
        get_token_info_by_graph, get_token_info_by_graph_name, set_tensor_data_u8,
    },
    Graph, RunningMode, CHAT_GRAPHS, OUTPUT_TENSOR,
};
use chat_prompts::{BuildChatPrompt, ChatPrompt, PromptTemplateType};
use either::{Either, Left, Right};
use endpoints::{
    // chat::{
    //     ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkChoiceDelta,
    //     ChatCompletionObject, ChatCompletionObjectChoice, ChatCompletionObjectMessage,
    //     ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRole,
    //     ChatCompletionUserMessageContent, ContentPart, Function, ToolCall, ToolCallForChunk,
    //     ToolChoice,
    // },
    chat::{ChatCompletionRequestMessage, ChatCompletionRole, ChatCompletionUserMessageContent},
    responses::{
        items::{ResponseOutputItem, ResponseOutputItemOutputMessageContent},
        response_object::{
            Input, InputItem, InputMessageContent, InputTokensDetails, OutputTokensDetails,
            RequestOfModelResponse, ResponseObject, ToolChoice, Usage,
        },
    },
    // common::{FinishReason, Usage},
};
use error::{BackendError, LlamaCoreError};
use std::{
    collections::{HashMap, VecDeque},
    pin::Pin,
    sync::{
        atomic::{AtomicBool, Ordering},
        Mutex, OnceLock,
    },
    task::{Context, Poll, Waker},
    time::SystemTime,
};
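
// The two statics below are used by `ChatStream` instances to coordinate with one another:
// the atomic flag records whether a stream is currently driving the model, and the waker queue
// parks streams created in the meantime so they can be woken up later.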

// Define a global waker queue for storing waiting ChatStreams
static CHAT_STREAM_WAKER_QUEUE: OnceLock<Mutex<VecDeque<Waker>>> = OnceLock::new();

// Define a global atomic boolean indicating whether there is an active ChatStream
static CHAT_STREAM_ACTIVE: AtomicBool = AtomicBool::new(false);

/// Processes a model response request and returns either a stream of generated output chunks
/// or a complete `ResponseObject`, together with a flag indicating whether tool calls are included.
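///
/// # Example
///
/// The sketch below is illustrative only (hence `ignore`): it assumes the caller has already
/// built a `RequestOfModelResponse`, and the surrounding `handle` function is just a stand-in
/// for real caller code. It merely shows how to branch on the returned `Either`.
///
/// ```ignore
/// use either::Either;
/// use endpoints::responses::response_object::RequestOfModelResponse;
///
/// async fn handle(mut request: RequestOfModelResponse) -> Result<(), LlamaCoreError> {
///     let (result, include_tool_calls) = chat(&mut request).await?;
///     match result {
///         // Stream mode: each `Ok` item is one chunk of the serialized response.
///         Either::Left(_stream) => { /* forward the TryStream to the client */ }
///         // Non-stream mode: a complete `ResponseObject` is returned.
///         Either::Right(response) => println!("response id: {}", response.id),
///     }
///     // `include_tool_calls` tells the caller whether the generated output contains tool calls.
///     let _ = include_tool_calls;
///     Ok(())
/// }
/// ```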
pub async fn chat(
    chat_request: &mut RequestOfModelResponse,
) -> Result<
    (
        Either<impl futures::TryStream<Ok = String, Error = LlamaCoreError>, ResponseObject>,
        bool,
    ),
    LlamaCoreError,
> {
    #[cfg(feature = "logging")]
    {
        debug!(target: "stdout", "tool choice: {:?}", &chat_request.tool_choice);
        debug!(target: "stdout", "tools: {:?}", chat_request.tools.as_ref());
        debug!(target: "stdout", "stream mode: {:?}", chat_request.stream);
    }

    let result = match chat_request.stream {
        Some(true) => match chat_stream(chat_request).await {
            Ok((stream, include_tool_calls)) => Ok((Left(stream), include_tool_calls)),
            Err(e) => Err(e),
        },
        Some(false) | None => match chat_once(chat_request).await {
            Ok((chat_completion_object, include_tool_calls)) => {
                Ok((Right(chat_completion_object), include_tool_calls))
            }
            Err(e) => Err(e),
        },
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    result
}
86
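/// Handles a model response request in stream mode.
///
/// Validates the running mode, updates the model metadata, builds the chat prompt, feeds it to
/// the model, and returns a `ChatStream` that yields the generated output incrementally. The
/// second element of the returned tuple indicates whether tool calls are included; streaming
/// with tool use is not implemented yet (see the `todo!` below).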
87async fn chat_stream(
88 chat_request: &mut RequestOfModelResponse,
89) -> Result<
90 (
91 impl futures::TryStream<Ok = String, Error = LlamaCoreError>,
92 bool,
93 ),
94 LlamaCoreError,
95> {
96 #[cfg(feature = "logging")]
97 info!(target: "stdout", "Process chat completion request in the stream mode");
98
99 let running_mode = running_mode()?;
100 if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
101 let err_msg = "The chat completion is only supported in the chat or rag mode.";
102
103 #[cfg(feature = "logging")]
104 error!(target: "stdout", "{err_msg}");
105
106 return Err(LlamaCoreError::Operation(err_msg.to_string()));
107 }
108
109 let model_name = chat_request.model.clone();
110
111 #[cfg(feature = "logging")]
112 info!(target: "stdout", "Check model metadata");
113
114 // update metadata
115 let mut metadata = check_model_metadata(chat_request)?;
116
117 // parse the `include_usage` option
118 // let include_usage = match chat_request.stream_options {
119 // Some(ref stream_options) => stream_options.include_usage.unwrap_or_default(),
120 // None => metadata.include_usage,
121 // };
122 // #[cfg(feature = "logging")]
123 // info!(target: "stdout", "include_usage: {include_usage}");
124
125 #[cfg(feature = "logging")]
126 info!(target: "stdout", "Build the chat prompt");
127
    // build prompt
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;
131
132 #[cfg(feature = "logging")]
133 {
134 info!(target: "stdout", "prompt:\n{}", &prompt);
135 info!(target: "stdout", "available_completion_tokens: {avaible_completion_tokens}");
136 info!(target: "stdout", "tool_use: {tool_use}");
137 }
138
139 #[cfg(feature = "logging")]
140 info!(target: "stdout", "Update the n_predict");
141
    // update metadata n_predict
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
144
145 #[cfg(feature = "logging")]
146 info!(target: "stdout", "Feed the prompt to the model");
147
148 // set prompt
149 set_prompt(chat_request.model.as_ref(), &prompt)?;
150
151 let stream = match tool_use {
152 false => (ChatStream::new(model_name, None), false),
153 true => {
154 todo!("Implement the streaming with tool use")
155
156 // let chat_graphs = match CHAT_GRAPHS.get() {
157 // Some(chat_graphs) => chat_graphs,
158 // None => {
159 // let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
160
161 // #[cfg(feature = "logging")]
162 // error!(target: "stdout", "{}", &err_msg);
163
164 // return Err(LlamaCoreError::Operation(err_msg.into()));
165 // }
166 // };
167
168 // let mut chat_graphs = chat_graphs.lock().map_err(|e| {
169 // let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
170
171 // #[cfg(feature = "logging")]
172 // error!(target: "stdout", "{}", &err_msg);
173
174 // LlamaCoreError::Operation(err_msg)
175 // })?;
176
177 // match model_name {
178 // Some(model_name) => match chat_graphs.contains_key(&model_name) {
179 // true => {
180 // let graph = chat_graphs.get_mut(&model_name).unwrap();
181 // chat_stream_for_tool(graph, id, include_usage)?
182 // }
183 // false => match chat_graphs.iter_mut().next() {
184 // Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
185 // None => {
186 // let err_msg = "There is no model available in the chat graphs.";
187
188 // #[cfg(feature = "logging")]
189 // error!(target: "stdout", "{}", &err_msg);
190
191 // return Err(LlamaCoreError::Operation(err_msg.into()));
192 // }
193 // },
194 // },
195 // None => match chat_graphs.iter_mut().next() {
196 // Some((_, graph)) => chat_stream_for_tool(graph, id, include_usage)?,
197 // None => {
198 // let err_msg = "There is no model available in the chat graphs.";
199
200 // #[cfg(feature = "logging")]
201 // error!(target: "stdout", "{}", &err_msg);
202
203 // return Err(LlamaCoreError::Operation(err_msg.into()));
204 // }
205 // },
206 // }
207 }
208 };
209
210 #[cfg(feature = "logging")]
211 info!(target: "stdout", "End of the chat completion stream.");
212
213 Ok(stream)
214}
215
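/// Handles a model response request in non-stream mode.
///
/// Validates the running mode, updates the model metadata, builds the chat prompt, feeds it to
/// the model, computes the completion, and finally resets the model metadata. Returns the
/// assembled `ResponseObject` together with a flag indicating whether tool calls are included.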
216async fn chat_once(
217 chat_request: &mut RequestOfModelResponse,
218) -> Result<(ResponseObject, bool), LlamaCoreError> {
219 #[cfg(feature = "logging")]
220 info!(target: "stdout", "Processing chat completion request in non-stream mode");
221
222 let running_mode = running_mode()?;
223 if !running_mode.contains(RunningMode::CHAT) && !running_mode.contains(RunningMode::RAG) {
224 let err_msg = "The chat completion is only supported in the chat or rag mode.";
225
226 #[cfg(feature = "logging")]
227 error!(target: "stdout", "{err_msg}");
228
229 return Err(LlamaCoreError::Operation(err_msg.to_string()));
230 }
231
232 let model_name = chat_request.model.clone();
233
234 // let id = match &chat_request.user {
235 // Some(id) => id.clone(),
236 // None => gen_chat_id(),
237 // };
238
239 // #[cfg(feature = "logging")]
240 // info!(target: "stdout", "user: {}", &id);
241
242 #[cfg(feature = "logging")]
243 info!(target: "stdout", "Check model metadata");
244
245 // update metadata
246 let mut metadata = check_model_metadata(chat_request)?;
247
248 #[cfg(feature = "logging")]
249 info!(target: "stdout", "Build the chat prompt");
250
    // build prompt
    let (prompt, available_completion_tokens, tool_use) =
        build_prompt(model_name.as_ref(), chat_request)?;
254
255 #[cfg(feature = "logging")]
256 {
257 info!(target: "stdout", "prompt:\n{}", &prompt);
258 info!(target: "stdout", "available_completion_tokens: {avaible_completion_tokens}");
259 info!(target: "stdout", "tool_use: {tool_use}");
260 }
261
262 #[cfg(feature = "logging")]
263 info!(target: "stdout", "Update n_predict");
264
    // update metadata n_predict
    update_n_predict(chat_request, &mut metadata, available_completion_tokens)?;
267
268 #[cfg(feature = "logging")]
269 info!(target: "stdout", "Feed the prompt to the model");
270
271 // feed the prompt to the model
272 set_prompt(model_name.as_ref(), &prompt)?;
273
274 #[cfg(feature = "logging")]
275 info!(target: "stdout", "Compute chat completion.");
276
277 // compute
278 let res = compute(chat_request, model_name.as_ref(), tool_use);
279
280 #[cfg(feature = "logging")]
281 info!(target: "stdout", "End of the chat completion");
282
283 // reset the model metadata
284 reset_model_metadata(model_name.as_ref())?;
285
286 res
287}
288
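/// Resolves the target graph in `CHAT_GRAPHS` (by model name, falling back to the first loaded
/// model) and delegates the actual inference to `compute_by_graph`.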
289fn compute(
290 chat_request: &mut RequestOfModelResponse,
291 model_name: Option<&String>,
292 tool_use: bool,
293) -> Result<(ResponseObject, bool), LlamaCoreError> {
294 let chat_graphs = match CHAT_GRAPHS.get() {
295 Some(chat_graphs) => chat_graphs,
296 None => {
297 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
298
299 #[cfg(feature = "logging")]
300 error!(target: "stdout", "{}", &err_msg);
301
302 return Err(LlamaCoreError::Operation(err_msg.into()));
303 }
304 };
305
306 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
307 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
308
309 #[cfg(feature = "logging")]
310 error!(target: "stdout", "{}", &err_msg);
311
312 LlamaCoreError::Operation(err_msg)
313 })?;
314
315 match model_name {
316 Some(model_name) => match chat_graphs.contains_key(model_name) {
317 true => {
318 let graph = chat_graphs.get_mut(model_name).unwrap();
319 compute_by_graph(chat_request, graph, tool_use)
320 }
321 false => match chat_graphs.iter_mut().next() {
322 Some((_, graph)) => compute_by_graph(chat_request, graph, tool_use),
323 None => {
324 let err_msg = "There is no model available in the chat graphs.";
325
326 #[cfg(feature = "logging")]
327 error!(target: "stdout", "{}", &err_msg);
328
329 Err(LlamaCoreError::Operation(err_msg.into()))
330 }
331 },
332 },
333 None => match chat_graphs.iter_mut().next() {
334 Some((_, graph)) => compute_by_graph(chat_request, graph, tool_use),
335 None => {
336 let err_msg = "There is no model available in the chat graphs.";
337
338 #[cfg(feature = "logging")]
339 error!(target: "stdout", "{}", &err_msg);
340
341 Err(LlamaCoreError::Operation(err_msg.into()))
342 }
343 },
344 }
345}
346
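/// Runs inference on the given graph, decodes and post-processes the raw output, and assembles
/// a `ResponseObject` from it. `ContextFull` and `PromptTooLong` backend errors are handled
/// leniently: whatever partial output is available is still returned as a completed response.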
347fn compute_by_graph(
348 chat_request: &mut RequestOfModelResponse,
349 graph: &mut Graph<GgmlMetadata>,
350 tool_use: bool,
351) -> Result<(ResponseObject, bool), LlamaCoreError> {
352 #[cfg(feature = "logging")]
353 info!(target: "stdout", "Compute chat completion by the model named {}.", graph.name());
354
355 match graph.compute() {
356 Ok(_) => {
357 // Retrieve the output.
358 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
359 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
360 let err_msg = format!(
361 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
362 );
363
364 #[cfg(feature = "logging")]
365 error!(target: "stdout", "{}", &err_msg);
366
367 LlamaCoreError::Operation(err_msg)
368 })?;
369
370 #[cfg(feature = "logging")]
371 info!(target: "stdout", "raw generation: {output}");
372
373 // post-process
374 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
375 LlamaCoreError::Operation(format!("Failed to post-process the output. {e}"))
376 })?;
377
378 #[cfg(feature = "logging")]
379 info!(target: "stdout", "post-processed generation:\n{}", &message);
380
381 // retrieve the number of prompt and completion tokens
382 let token_info = get_token_info_by_graph(graph)?;
383
384 #[cfg(feature = "logging")]
385 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
386
387 let created = SystemTime::now()
388 .duration_since(std::time::UNIX_EPOCH)
389 .map_err(|e| {
390 let err_msg = format!("Failed to get the current time. Reason: {e}");
391
392 #[cfg(feature = "logging")]
393 error!(target: "stdout", "{}", &err_msg);
394
395 LlamaCoreError::Operation(err_msg)
396 })?;
397
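            // Tool-use responses require a tool-aware prompt template and are not assembled yet
            // (see the `unimplemented!` below); the plain-text path builds the `ResponseObject`
            // directly from the post-processed message.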
398 match tool_use {
399 true => {
400 if graph.metadata.prompt_template != PromptTemplateType::MistralTool
401 && graph.metadata.prompt_template != PromptTemplateType::ChatMLTool
402 && graph.metadata.prompt_template != PromptTemplateType::GroqLlama3Tool
403 && graph.metadata.prompt_template != PromptTemplateType::Llama3Tool
404 && graph.metadata.prompt_template != PromptTemplateType::InternLM2Tool
405 && graph.metadata.prompt_template != PromptTemplateType::NemotronTool
406 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV32
407 && graph.metadata.prompt_template != PromptTemplateType::FunctionaryV31
408 && graph.metadata.prompt_template != PromptTemplateType::MistralSmallTool
409 && graph.metadata.prompt_template != PromptTemplateType::Llama4Chat
410 && graph.metadata.prompt_template != PromptTemplateType::Qwen3NoThink
411 && graph.metadata.prompt_template != PromptTemplateType::Smol3NoThink
412 && graph.metadata.prompt_template != PromptTemplateType::Gemma3
413 && graph.metadata.prompt_template != PromptTemplateType::GptOss
414 && graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent
415 && graph.metadata.prompt_template != PromptTemplateType::SeedOssNoThink
416 && graph.metadata.prompt_template != PromptTemplateType::SeedOssThink
417 {
418 let err_msg = format!("Unsupported prompt template: {}. The tool use is only supported for 'mistral-tool', 'chatml-tool', 'groq-llama3-tool', 'llama-3-tool', 'internlm-2-tool', 'nemotron-tool', 'functionary-31', 'functionary-32', 'mistral-small-tool', 'llama-4-chat', 'qwen3-no-think', 'smol-3-no-think', 'gemma-3', 'gpt-oss', 'qwen3-agent', 'seed-oss-no-think', and 'seed-oss-think' prompt templates.", graph.metadata.prompt_template);
419
420 #[cfg(feature = "logging")]
421 error!(target: "stdout", "{}", &err_msg);
422
423 return Err(LlamaCoreError::Operation(err_msg));
424 }
425
426 // let parsed_result = parse_tool_calls(&message, graph.metadata.prompt_template)?;
427
428 // let (finish_reason, content, include_tool_calls) =
429 // if parsed_result.tool_calls.is_empty() {
430 // (FinishReason::stop, Some(parsed_result.raw.clone()), false)
431 // } else if graph.metadata.prompt_template != PromptTemplateType::Qwen3Agent {
432 // (
433 // FinishReason::tool_calls,
434 // Some(parsed_result.raw.clone()),
435 // true,
436 // )
437 // } else {
438 // (
439 // FinishReason::tool_calls,
440 // parsed_result.content.clone(),
441 // true,
442 // )
443 // };
444
445 // let res = ChatCompletionObject {
446 // id: id.into(),
447 // object: String::from("chat.completion"),
448 // created: created.as_secs(),
449 // model: graph.name().to_owned(),
450 // choices: vec![ChatCompletionObjectChoice {
451 // index: 0,
452 // message: ChatCompletionObjectMessage {
453 // role: ChatCompletionRole::Assistant,
454 // content,
455 // tool_calls: parsed_result.tool_calls,
456 // function_call: None,
457 // },
458 // finish_reason,
459 // logprobs: None,
460 // }],
461 // usage: Usage {
462 // prompt_tokens: token_info.prompt_tokens,
463 // completion_tokens: token_info.completion_tokens,
464 // total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
465 // },
466 // };
467
468 // // create ChatCompletionResponse
469 // Ok((res, include_tool_calls))
470
471 unimplemented!("Return ResponseObject with tool calls")
472 }
473 false => {
474 let message = ResponseOutputItemOutputMessageContent::OutputText {
475 annotations: vec![],
476 text: message.to_string(),
477 ty: "output_text".to_string(),
478 logprobs: None,
479 };
480
481 let output_message = ResponseOutputItem::OutputMessage {
482 content: vec![message],
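                        // NOTE: hard-coded placeholder message id; presumably a unique id should
                        // be generated for each output message.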
483 id: "msg_67ccd3acc8d48190a77525dc6de64b4104becb25c45c1d41".to_string(),
484 role: ChatCompletionRole::Assistant.to_string(),
485 status: "completed".to_string(),
486 ty: "message".to_string(),
487 };
488
489 let temperature = match &chat_request.temperature {
490 Some(t) => *t,
491 None => graph.metadata.temperature,
492 };
493
494 let top_p = match &chat_request.top_p {
495 Some(t) => *t,
496 None => graph.metadata.top_p,
497 };
498
499 let res = ResponseObject {
500 background: false,
501 conversation: None,
502 created_at: created.as_secs(),
503 error: None,
504 id: gen_response_id(),
505 incomplete_details: None,
506 instructions: None,
507 max_output_tokens: None,
508 max_tool_calls: None,
509 metadata: HashMap::new(),
510 model: graph.name().to_owned(),
511 object: "response".to_string(),
512 output: vec![output_message],
513 parallel_tool_calls: true,
514 previous_response_id: None,
515 safety_identifier: None,
516 status: "completed".to_string(),
517 temperature,
518 tool_choice: chat_request.tool_choice.clone(),
519 tools: chat_request.tools.clone(),
520 top_p,
521 truncation: None,
522 usage: Usage {
523 input_tokens: token_info.prompt_tokens,
524 input_tokens_details: InputTokensDetails { cached_tokens: 0 },
525 output_tokens: token_info.completion_tokens,
526 output_tokens_details: OutputTokensDetails {
527 reasoning_tokens: 0,
528 },
529 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
530 },
531 };
532
533 Ok((res, false))
534 }
535 }
536 }
537 Err(wasmedge_wasi_nn::Error::BackendError(wasmedge_wasi_nn::BackendError::ContextFull)) => {
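            // The context window is full: salvage whatever partial output is available and
            // return it as a regular, completed response instead of failing the request.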
538 // Retrieve the output.
539 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
540 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
541 let err_msg = format!(
542 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
543 );
544
545 #[cfg(feature = "logging")]
546 error!(target: "stdout", "{}", &err_msg);
547
548 LlamaCoreError::Operation(err_msg)
549 })?;
550
551 // post-process
552 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
553 let err_msg = format!("Failed to post-process the output. {e}");
554
555 #[cfg(feature = "logging")]
556 error!(target: "stdout", "{}", &err_msg);
557
558 LlamaCoreError::Operation(err_msg)
559 })?;
560
561 // retrieve the number of prompt and completion tokens
562 let token_info = get_token_info_by_graph(graph)?;
563
564 #[cfg(feature = "logging")]
565 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
566
567 let created = SystemTime::now()
568 .duration_since(std::time::UNIX_EPOCH)
569 .map_err(|e| {
570 let err_msg = format!("Failed to get the current time. Reason: {e}");
571
572 #[cfg(feature = "logging")]
573 error!(target: "stdout", "{}", &err_msg);
574
575 LlamaCoreError::Operation(err_msg)
576 })?;
577
578 let message = ResponseOutputItemOutputMessageContent::OutputText {
579 annotations: vec![],
580 text: message.to_string(),
581 ty: "output_text".to_string(),
582 logprobs: None,
583 };
584
585 let output_message = ResponseOutputItem::OutputMessage {
586 content: vec![message],
587 id: "msg_67ccd3acc8d48190a77525dc6de64b4104becb25c45c1d41".to_string(),
588 role: ChatCompletionRole::Assistant.to_string(),
589 status: "completed".to_string(),
590 ty: "message".to_string(),
591 };
592
593 let temperature = match &chat_request.temperature {
594 Some(t) => *t,
595 None => graph.metadata.temperature,
596 };
597
598 let top_p = match &chat_request.top_p {
599 Some(t) => *t,
600 None => graph.metadata.top_p,
601 };
602
603 let res = ResponseObject {
604 background: false,
605 conversation: None,
606 created_at: created.as_secs(),
607 error: None,
608 id: gen_response_id(),
609 incomplete_details: None,
610 instructions: None,
611 max_output_tokens: None,
612 max_tool_calls: None,
613 metadata: HashMap::new(),
614 model: graph.name().to_owned(),
615 object: "response".to_string(),
616 output: vec![output_message],
617 parallel_tool_calls: true,
618 previous_response_id: None,
619 safety_identifier: None,
620 status: "completed".to_string(),
621 temperature,
622 tool_choice: chat_request.tool_choice.clone(),
623 tools: chat_request.tools.clone(),
624 top_p,
625 truncation: None,
626 usage: Usage {
627 input_tokens: token_info.prompt_tokens,
628 input_tokens_details: InputTokensDetails { cached_tokens: 0 },
629 output_tokens: token_info.completion_tokens,
630 output_tokens_details: OutputTokensDetails {
631 reasoning_tokens: 0,
632 },
633 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
634 },
635 };
636
637 Ok((res, false))
638 }
639 Err(wasmedge_wasi_nn::Error::BackendError(
640 wasmedge_wasi_nn::BackendError::PromptTooLong,
641 )) => {
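            // The backend reported the prompt as too long; still return whatever output is
            // available rather than failing the request.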
642 #[cfg(feature = "logging")]
643 warn!(target: "stdout", "The prompt is too long. Please reduce the length of your input and try again.");
644
645 // Retrieve the output.
646 let output_buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;
647 let output = std::str::from_utf8(&output_buffer[..]).map_err(|e| {
648 let err_msg = format!(
649 "Failed to decode the buffer of the inference result to a utf-8 string. {e}"
650 );
651
652 #[cfg(feature = "logging")]
653 error!(target: "stdout", "{}", &err_msg);
654
655 LlamaCoreError::Operation(err_msg)
656 })?;
657
658 // post-process
659 let message = post_process(output, &graph.metadata.prompt_template).map_err(|e| {
660 let err_msg = format!("Failed to post-process the output. {e}");
661
662 #[cfg(feature = "logging")]
663 error!(target: "stdout", "{}", &err_msg);
664
665 LlamaCoreError::Operation(err_msg)
666 })?;
667
            // retrieve the number of prompt and completion tokens
669 let token_info = get_token_info_by_graph(graph)?;
670
671 #[cfg(feature = "logging")]
672 info!(target: "stdout", "prompt tokens: {}, completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);
673
674 let created = SystemTime::now()
675 .duration_since(std::time::UNIX_EPOCH)
676 .map_err(|e| {
677 let err_msg = format!("Failed to get the current time. Reason: {e}");
678
679 #[cfg(feature = "logging")]
680 error!(target: "stdout", "{}", &err_msg);
681
682 LlamaCoreError::Operation(err_msg)
683 })?;
684
685 let message = ResponseOutputItemOutputMessageContent::OutputText {
686 annotations: vec![],
687 text: message.to_string(),
688 ty: "output_text".to_string(),
689 logprobs: None,
690 };
691
692 let output_message = ResponseOutputItem::OutputMessage {
693 content: vec![message],
694 id: "msg_67ccd3acc8d48190a77525dc6de64b4104becb25c45c1d41".to_string(),
695 role: ChatCompletionRole::Assistant.to_string(),
696 status: "completed".to_string(),
697 ty: "message".to_string(),
698 };
699
700 let temperature = match &chat_request.temperature {
701 Some(t) => *t,
702 None => graph.metadata.temperature,
703 };
704
705 let top_p = match &chat_request.top_p {
706 Some(t) => *t,
707 None => graph.metadata.top_p,
708 };
709
710 let res = ResponseObject {
711 background: false,
712 conversation: None,
713 created_at: created.as_secs(),
714 error: None,
715 id: gen_response_id(),
716 incomplete_details: None,
717 instructions: None,
718 max_output_tokens: None,
719 max_tool_calls: None,
720 metadata: HashMap::new(),
721 model: graph.name().to_owned(),
722 object: "response".to_string(),
723 output: vec![output_message],
724 parallel_tool_calls: true,
725 previous_response_id: None,
726 safety_identifier: None,
727 status: "completed".to_string(),
728 temperature,
729 tool_choice: chat_request.tool_choice.clone(),
730 tools: chat_request.tools.clone(),
731 top_p,
732 truncation: None,
733 usage: Usage {
734 input_tokens: token_info.prompt_tokens,
735 input_tokens_details: InputTokensDetails { cached_tokens: 0 },
736 output_tokens: token_info.completion_tokens,
737 output_tokens_details: OutputTokensDetails {
738 reasoning_tokens: 0,
739 },
740 total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
741 },
742 };
743
744 Ok((res, false))
745 }
746 Err(e) => {
747 let err_msg = format!("Failed to compute the chat completion. Reason: {e}");
748
749 #[cfg(feature = "logging")]
750 error!(target: "stdout", "{}", &err_msg);
751
752 Err(LlamaCoreError::Backend(BackendError::Compute(err_msg)))
753 }
754 }
755}
756
757// fn parse_tool_calls(
758// input: &str,
759// prompt_template: PromptTemplateType,
760// ) -> Result<ParseResult, LlamaCoreError> {
761// match prompt_template {
762// PromptTemplateType::MistralTool => match regex::Regex::new(r"\[\{.*?\}\]") {
763// Ok(re) => {
764// let mut values: Vec<serde_json::Value> = vec![];
765// for cap in re.captures_iter(input) {
766// let matched = &cap[0];
767
768// #[cfg(feature = "logging")]
769// info!(target: "stdout", "captured: {matched}");
770
771// match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
772// Ok(group) => values.extend(group),
773// Err(e) => {
774// let err_msg =
775// format!("Failed to deserialize generated tool calls. Reason: {e}");
776
777// #[cfg(feature = "logging")]
778// error!(target: "stdout", "{}", &err_msg);
779
780// return Err(LlamaCoreError::Operation(err_msg));
781// }
782// }
783// }
784
785// let mut tool_calls: Vec<ToolCall> = vec![];
786// for value in values.iter() {
787// let name = match value.get("name") {
788// Some(name) => name.to_string().replace("\"", ""),
789// None => {
790// let err_msg = format!(
791// "Failed to get the name of the function. Tool call: {value:?}"
792// );
793
794// #[cfg(feature = "logging")]
795// error!(target: "stdout", "{}", &err_msg);
796
797// return Err(LlamaCoreError::Operation(err_msg));
798// }
799// };
800
801// let arguments = match value.get("arguments") {
802// Some(arguments) => arguments.to_string(),
803// None => {
804// let err_msg = format!(
805// "Failed to get the arguments of the function. Tool call: {value:?}"
806// );
807
808// #[cfg(feature = "logging")]
809// error!(target: "stdout", "{}", &err_msg);
810
811// return Err(LlamaCoreError::Operation(err_msg));
812// }
813// };
814
815// let function = Function { name, arguments };
816
817// let tool_call = ToolCall {
818// id: "call_abc123".to_string(),
819// ty: "function".to_string(),
820// function,
821// };
822
823// tool_calls.push(tool_call);
824// }
825
826// let parsed = ParseResult {
827// raw: input.to_owned(),
828// content: None,
829// tool_calls,
830// };
831
832// #[cfg(feature = "logging")]
833// info!(target: "stdout", "parsed result: {parsed:?}");
834
835// Ok(parsed)
836// }
837// Err(e) => {
838// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
839
840// #[cfg(feature = "logging")]
841// error!(target: "stdout", "{}", &err_msg);
842
843// Err(LlamaCoreError::Operation(err_msg))
844// }
845// },
846// PromptTemplateType::ChatMLTool => {
847// match regex::Regex::new(r"<tool_call>(.*?)</tool_call>") {
848// Ok(re) => {
849// let mut values: Vec<serde_json::Value> = vec![];
850// for cap in re.captures_iter(input) {
851// let matched = cap[1].replace("\\n", ""); // Remove "\\n" from the captured group
852
853// #[cfg(feature = "logging")]
854// info!(target: "stdout", "captured: {}", &matched);
855
856// match serde_json::from_str::<serde_json::Value>(&matched) {
857// Ok(value) => values.push(value),
858// Err(e) => {
859// let err_msg = format!(
860// "Failed to deserialize generated tool calls. Reason: {e}"
861// );
862
863// #[cfg(feature = "logging")]
864// error!(target: "stdout", "{}", &err_msg);
865
866// return Err(LlamaCoreError::Operation(err_msg));
867// }
868// }
869// }
870
871// let mut tool_calls: Vec<ToolCall> = vec![];
872// for value in values.iter() {
873// let name = match value.get("name") {
874// Some(name) => name.to_string().replace("\"", ""),
875// None => {
876// let err_msg = format!(
877// "Failed to get the name of the function. Tool call: {value:?}"
878// );
879
880// #[cfg(feature = "logging")]
881// error!(target: "stdout", "{}", &err_msg);
882
883// return Err(LlamaCoreError::Operation(err_msg));
884// }
885// };
886
887// let arguments = match value.get("arguments") {
888// Some(arguments) => arguments.to_string(),
889// None => {
890// let err_msg = format!(
891// "Failed to get the arguments of the function. Tool call: {value:?}"
892// );
893
894// #[cfg(feature = "logging")]
895// error!(target: "stdout", "{}", &err_msg);
896
897// return Err(LlamaCoreError::Operation(err_msg));
898// }
899// };
900
901// let function = Function { name, arguments };
902
903// let tool_call = ToolCall {
904// id: "call_abc123".to_string(),
905// ty: "function".to_string(),
906// function,
907// };
908
909// tool_calls.push(tool_call);
910// }
911
912// let parsed = ParseResult {
913// raw: input.to_owned(),
914// content: None,
915// tool_calls,
916// };
917
918// #[cfg(feature = "logging")]
919// info!(target: "stdout", "parsed result: {parsed:?}");
920
921// Ok(parsed)
922// }
923// Err(e) => {
924// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
925
926// #[cfg(feature = "logging")]
927// error!(target: "stdout", "{}", &err_msg);
928
929// Err(LlamaCoreError::Operation(err_msg))
930// }
931// }
932// }
933// PromptTemplateType::GroqLlama3Tool => {
934// #[cfg(feature = "logging")]
935// info!(target: "stdout", "raw input: {input}");
936
937// match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
938// Ok(re) => {
939// let mut values: Vec<serde_json::Value> = vec![];
940// for cap in re.captures_iter(input) {
941// let matched = cap[1].trim();
942
943// #[cfg(feature = "logging")]
944// info!(target: "stdout", "captured: {matched}");
945
946// match serde_json::from_str::<serde_json::Value>(matched) {
947// Ok(value) => values.push(value),
948// Err(e) => {
949// let err_msg = format!(
950// "Failed to deserialize generated tool calls. Reason: {e}"
951// );
952
953// #[cfg(feature = "logging")]
954// error!(target: "stdout", "{}", &err_msg);
955
956// return Err(LlamaCoreError::Operation(err_msg));
957// }
958// }
959// }
960
961// let mut tool_calls: Vec<ToolCall> = vec![];
962// for value in values.iter() {
963// let name = match value.get("name") {
964// Some(name) => name.to_string().replace("\"", ""),
965// None => {
966// let err_msg = format!(
967// "Failed to get the name of the function. Tool call: {value:?}"
968// );
969
970// #[cfg(feature = "logging")]
971// error!(target: "stdout", "{}", &err_msg);
972
973// return Err(LlamaCoreError::Operation(err_msg));
974// }
975// };
976
977// let arguments = match value.get("arguments") {
978// Some(arguments) => {
979// if arguments.is_string() {
980// arguments.as_str().unwrap().to_string()
981// } else if arguments.is_object() {
982// let map = arguments.as_object().unwrap();
983
984// #[cfg(feature = "logging")]
985// info!(target: "stdout", "func arguments: {map:?}");
986
987// serde_json::to_string(map).unwrap()
988// } else {
989// serde_json::to_string(arguments).unwrap()
990// }
991// }
992// None => {
993// let err_msg = format!(
994// "Failed to get the arguments of the function. Tool call: {value:?}"
995// );
996
997// #[cfg(feature = "logging")]
998// error!(target: "stdout", "{}", &err_msg);
999
1000// return Err(LlamaCoreError::Operation(err_msg));
1001// }
1002// };
1003
1004// let function = Function { name, arguments };
1005
1006// let tool_call = ToolCall {
1007// id: "call_abc123".to_string(),
1008// ty: "function".to_string(),
1009// function,
1010// };
1011
1012// tool_calls.push(tool_call);
1013// }
1014
1015// let parsed = if tool_calls.is_empty() {
1016// ParseResult {
1017// raw: input.to_owned(),
1018// content: Some(input.to_owned()),
1019// tool_calls: vec![],
1020// }
1021// } else {
1022// ParseResult {
1023// raw: input.to_owned(),
1024// content: None,
1025// tool_calls,
1026// }
1027// };
1028
1029// #[cfg(feature = "logging")]
1030// info!(target: "stdout", "parsed result: {parsed:?}");
1031
1032// Ok(parsed)
1033// }
1034// Err(e) => {
1035// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1036
1037// #[cfg(feature = "logging")]
1038// error!(target: "stdout", "{}", &err_msg);
1039
1040// Err(LlamaCoreError::Operation(err_msg))
1041// }
1042// }
1043// }
1044// PromptTemplateType::Llama3Tool => {
1045// #[cfg(feature = "logging")]
1046// info!(target: "stdout", "raw input: {input}");
1047
1048// let re = match regex::Regex::new(r"^\{(.|\r|\n)*\}$") {
1049// Ok(re) => re,
1050// Err(e) => {
1051// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1052
1053// #[cfg(feature = "logging")]
1054// error!(target: "stdout", "{}", &err_msg);
1055
1056// return Err(LlamaCoreError::Operation(err_msg));
1057// }
1058// };
1059
1060// if re.is_match(input) {
1061// match serde_json::from_str::<serde_json::Value>(input) {
1062// Ok(value) => {
1063// let values: Vec<serde_json::Value> = vec![value];
1064
1065// let mut tool_calls: Vec<ToolCall> = vec![];
1066// for value in values.iter() {
1067// let name = match value.get("name") {
1068// Some(name) => name.to_string().replace("\"", ""),
1069// None => {
1070// let err_msg = format!(
1071// "Failed to get the name of the function. Tool call: {value:?}"
1072// );
1073
1074// #[cfg(feature = "logging")]
1075// error!(target: "stdout", "{}", &err_msg);
1076
1077// return Err(LlamaCoreError::Operation(err_msg));
1078// }
1079// };
1080
1081// let arguments = match value.get("parameters") {
1082// Some(arguments) => arguments.to_string(),
1083// None => {
1084// let err_msg = format!(
1085// "Failed to get the arguments of the function. Tool call: {value:?}"
1086// );
1087
1088// #[cfg(feature = "logging")]
1089// error!(target: "stdout", "{}", &err_msg);
1090
1091// return Err(LlamaCoreError::Operation(err_msg));
1092// }
1093// };
1094
1095// let function = Function { name, arguments };
1096
1097// let tool_call = ToolCall {
1098// id: "call_abc123".to_string(),
1099// ty: "function".to_string(),
1100// function,
1101// };
1102
1103// tool_calls.push(tool_call);
1104// }
1105
1106// let parsed = ParseResult {
1107// raw: input.to_owned(),
1108// content: None,
1109// tool_calls,
1110// };
1111
1112// #[cfg(feature = "logging")]
1113// info!(target: "stdout", "parsed result: {parsed:?}");
1114
1115// Ok(parsed)
1116// }
1117// Err(e) => {
1118// let err_msg =
1119// format!("Failed to deserialize generated tool calls. Reason: {e}");
1120
1121// #[cfg(feature = "logging")]
1122// error!(target: "stdout", "{}", &err_msg);
1123
1124// Err(LlamaCoreError::Operation(err_msg))
1125// }
1126// }
1127// } else {
1128// let parsed = ParseResult {
1129// raw: input.to_owned(),
1130// content: None,
1131// tool_calls: vec![],
1132// };
1133
1134// #[cfg(feature = "logging")]
1135// info!(target: "stdout", "parsed result: {parsed:?}");
1136
1137// Ok(parsed)
1138// }
1139// }
1140// PromptTemplateType::InternLM2Tool => {
1141// #[cfg(feature = "logging")]
1142// info!(target: "stdout", "raw input: {input}");
1143
1144// let blocks: Vec<&str> = input.trim().split("<|action_start|><|plugin|>").collect();
1145
1146// #[cfg(feature = "logging")]
1147// info!(target: "stdout", "blocks: {blocks:?}");
1148
1149// let mut tool_calls: Vec<ToolCall> = vec![];
1150// let mut content = String::new();
1151// for block in blocks {
1152// let block = block.trim();
1153// if !block.is_empty() {
1154// if block.ends_with("<|action_end|>") {
1155// let value = block.trim().trim_end_matches("<|action_end|>");
1156
1157// #[cfg(feature = "logging")]
1158// info!(target: "stdout", "tool call: {value}");
1159
1160// match serde_json::from_str::<serde_json::Value>(value) {
1161// Ok(value) => {
1162// let name = match value.get("name") {
1163// Some(name) => name.to_string().replace("\"", ""),
1164// None => {
1165// let err_msg = format!(
1166// "Failed to get the name of the function. Tool call: {value:?}"
1167// );
1168
1169// #[cfg(feature = "logging")]
1170// error!(target: "stdout", "{}", &err_msg);
1171
1172// return Err(LlamaCoreError::Operation(err_msg));
1173// }
1174// };
1175
1176// let arguments = match value.get("parameters") {
1177// Some(arguments) => arguments.to_string(),
1178// None => {
1179// let err_msg = format!(
1180// "Failed to get the arguments of the function. Tool call: {value:?}"
1181// );
1182
1183// #[cfg(feature = "logging")]
1184// error!(target: "stdout", "{}", &err_msg);
1185
1186// return Err(LlamaCoreError::Operation(err_msg));
1187// }
1188// };
1189
1190// let function = Function { name, arguments };
1191
1192// let tool_call = ToolCall {
1193// id: "call_abc123".to_string(),
1194// ty: "function".to_string(),
1195// function,
1196// };
1197
1198// tool_calls.push(tool_call);
1199// }
1200// Err(e) => {
1201// let err_msg = format!(
1202// "Failed to deserialize generated tool calls. Reason: {e}"
1203// );
1204
1205// #[cfg(feature = "logging")]
1206// error!(target: "stdout", "{}", &err_msg);
1207
1208// return Err(LlamaCoreError::Operation(err_msg));
1209// }
1210// }
1211// } else {
1212// content.push_str(block);
1213// content.push('\n');
1214// }
1215// }
1216// }
1217
1218// let parsed = match content.is_empty() {
1219// true => ParseResult {
1220// raw: input.to_owned(),
1221// content: None,
1222// tool_calls,
1223// },
1224// false => ParseResult {
1225// raw: input.to_owned(),
1226// content: Some(content.trim().to_owned()),
1227// tool_calls,
1228// },
1229// };
1230
1231// #[cfg(feature = "logging")]
1232// info!(target: "stdout", "parsed result: {parsed:?}");
1233
1234// Ok(parsed)
1235// }
1236// PromptTemplateType::NemotronTool => {
1237// #[cfg(feature = "logging")]
1238// info!(target: "stdout", "raw input: {input}");
1239
1240// match regex::Regex::new(r"(?s)<toolcall>\s*(.*?)\s*</toolcall>") {
1241// Ok(re) => {
1242// let mut values: Vec<serde_json::Value> = vec![];
1243// for cap in re.captures_iter(input) {
1244// #[cfg(feature = "logging")]
1245// info!(target: "stdout", "captured: {}", &cap[0]);
1246
1247// #[cfg(feature = "logging")]
1248// info!(target: "stdout", "extracted: {}", &cap[1]);
1249
1250// let matched = cap[1].trim();
1251
1252// #[cfg(feature = "logging")]
1253// info!(target: "stdout", "captured: {matched}");
1254
1255// match serde_json::from_str::<serde_json::Value>(matched) {
1256// Ok(value) => values.push(value),
1257// Err(e) => {
1258// let err_msg = format!(
1259// "Failed to deserialize generated tool calls. Reason: {e}"
1260// );
1261
1262// #[cfg(feature = "logging")]
1263// error!(target: "stdout", "{}", &err_msg);
1264
1265// return Err(LlamaCoreError::Operation(err_msg));
1266// }
1267// }
1268// }
1269
1270// let mut tool_calls: Vec<ToolCall> = vec![];
1271// for value in values.iter() {
1272// let name = match value.get("name") {
1273// Some(name) => name.to_string().replace("\"", ""),
1274// None => {
1275// let err_msg = format!(
1276// "Failed to get the name of the function. Tool call: {value:?}"
1277// );
1278
1279// #[cfg(feature = "logging")]
1280// error!(target: "stdout", "{}", &err_msg);
1281
1282// return Err(LlamaCoreError::Operation(err_msg));
1283// }
1284// };
1285
1286// let arguments = match value.get("arguments") {
1287// Some(arguments) => arguments.to_string(),
1288// None => {
1289// let err_msg = format!(
1290// "Failed to get the arguments of the function. Tool call: {value:?}"
1291// );
1292
1293// #[cfg(feature = "logging")]
1294// error!(target: "stdout", "{}", &err_msg);
1295
1296// return Err(LlamaCoreError::Operation(err_msg));
1297// }
1298// };
1299
1300// let function = Function { name, arguments };
1301
1302// let tool_call = ToolCall {
1303// id: "call_abc123".to_string(),
1304// ty: "function".to_string(),
1305// function,
1306// };
1307
1308// tool_calls.push(tool_call);
1309// }
1310
1311// let parsed = ParseResult {
1312// raw: input.to_owned(),
1313// content: None,
1314// tool_calls,
1315// };
1316
1317// #[cfg(feature = "logging")]
1318// info!(target: "stdout", "parsed result: {parsed:?}");
1319
1320// Ok(parsed)
1321// }
1322// Err(e) => {
1323// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1324
1325// #[cfg(feature = "logging")]
1326// error!(target: "stdout", "{}", &err_msg);
1327
1328// Err(LlamaCoreError::Operation(err_msg))
1329// }
1330// }
1331// }
1332// PromptTemplateType::FunctionaryV32 => {
1333// #[cfg(feature = "logging")]
1334// info!(target: "stdout", "raw input: {input}");
1335
1336// match regex::Regex::new(r">>>\s*(\w+)\s*\{(.*)\}<\|eot_id\|>") {
1337// Ok(re) => {
1338// let mut tool_calls: Vec<ToolCall> = vec![];
1339// for cap in re.captures_iter(input) {
1340// #[cfg(feature = "logging")]
1341// info!(target: "stdout", "func_name: {}", &cap[1]);
1342
1343// #[cfg(feature = "logging")]
1344// info!(target: "stdout", "arguments: {}", &cap[2]);
1345
1346// let tool_call = ToolCall {
1347// id: "call_abc123".to_string(),
1348// ty: "function".to_string(),
1349// function: Function {
1350// name: cap[1].to_string(),
1351// arguments: cap[2].to_string(),
1352// },
1353// };
1354
1355// tool_calls.push(tool_call);
1356// }
1357
1358// let parsed = ParseResult {
1359// raw: input.to_owned(),
1360// content: None,
1361// tool_calls,
1362// };
1363
1364// #[cfg(feature = "logging")]
1365// info!(target: "stdout", "parsed result: {parsed:?}");
1366
1367// Ok(parsed)
1368// }
1369// Err(e) => {
1370// let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1371
1372// #[cfg(feature = "logging")]
1373// warn!(target: "stdout", "{}", &warn_msg);
1374
1375// Ok(ParseResult {
1376// raw: input.to_owned(),
1377// content: None,
1378// tool_calls: vec![],
1379// })
1380// }
1381// }
1382// }
1383// PromptTemplateType::FunctionaryV31 => {
1384// #[cfg(feature = "logging")]
1385// info!(target: "stdout", "raw input: {input}");
1386
1387// match regex::Regex::new(r"<function=(\w+)>\s*(\{.*?\})</function>") {
1388// Ok(re) => {
1389// let mut tool_calls: Vec<ToolCall> = vec![];
1390// for cap in re.captures_iter(input) {
1391// #[cfg(feature = "logging")]
1392// info!(target: "stdout", "func_name: {}", &cap[1]);
1393
1394// #[cfg(feature = "logging")]
1395// info!(target: "stdout", "arguments: {}", &cap[2]);
1396
1397// let tool_call = ToolCall {
1398// id: "call_abc123".to_string(),
1399// ty: "function".to_string(),
1400// function: Function {
1401// name: cap[1].to_string(),
1402// arguments: cap[2].to_string(),
1403// },
1404// };
1405
1406// tool_calls.push(tool_call);
1407// }
1408
1409// let parsed = ParseResult {
1410// raw: input.to_owned(),
1411// content: None,
1412// tool_calls,
1413// };
1414
1415// #[cfg(feature = "logging")]
1416// info!(target: "stdout", "parsed result: {parsed:?}");
1417
1418// Ok(parsed)
1419// }
1420// Err(e) => {
1421// let warn_msg = format!("Failed to create a regex pattern. Reason: {e}");
1422
1423// #[cfg(feature = "logging")]
1424// warn!(target: "stdout", "{}", &warn_msg);
1425
1426// Ok(ParseResult {
1427// raw: input.to_owned(),
1428// content: None,
1429// tool_calls: vec![],
1430// })
1431// }
1432// }
1433// }
1434// PromptTemplateType::MistralSmallTool => {
1435// #[cfg(feature = "logging")]
1436// info!(target: "stdout", "raw input: {input}");
1437
1438// match regex::Regex::new(r"\[TOOL_CALLS\]\s*(\[(.*?)\])") {
1439// Ok(re) => {
1440// let mut values: Vec<serde_json::Value> = vec![];
1441// if let Some(cap) = re.captures(input) {
1442// let matched = cap[1].trim();
1443
1444// #[cfg(feature = "logging")]
1445// info!(target: "stdout", "captured: {matched}");
1446
1447// match serde_json::from_str::<Vec<serde_json::Value>>(matched) {
1448// Ok(vals) => values = vals,
1449// Err(e) => {
1450// let err_msg = format!(
1451// "Failed to deserialize generated tool calls. Reason: {e}"
1452// );
1453
1454// #[cfg(feature = "logging")]
1455// error!(target: "stdout", "{}", &err_msg);
1456
1457// return Err(LlamaCoreError::Operation(err_msg));
1458// }
1459// }
1460// };
1461
1462// let mut tool_calls: Vec<ToolCall> = vec![];
1463// for value in values.iter() {
1464// if let Some(object_map) = value.as_object() {
1465// if object_map.contains_key("function") {
1466// let mut function = Function {
1467// name: String::new(),
1468// arguments: String::new(),
1469// };
1470
1471// let value = object_map.get("function").unwrap();
1472// let func_map = value.as_object().unwrap();
1473// if func_map.contains_key("name") {
1474// let func_name = func_map.get("name").unwrap().as_str().unwrap();
1475// println!("Function name: {func_name:?}");
1476
1477// function.name = func_name.to_string();
1478// }
1479// if func_map.contains_key("arguments") {
1480// let args = func_map.get("arguments").unwrap();
1481// let arguments = args.to_string();
1482// println!("Arguments: {arguments:?}");
1483
1484// function.arguments = arguments;
1485// }
1486
1487// let tool_call = ToolCall {
1488// id: "call_abc123".to_string(),
1489// ty: "function".to_string(),
1490// function,
1491// };
1492
1493// tool_calls.push(tool_call);
1494// } else if object_map.contains_key("name") {
1495// let mut function = Function {
1496// name: String::new(),
1497// arguments: String::new(),
1498// };
1499
1500// let name = object_map.get("name").unwrap().as_str().unwrap();
1501// println!("name: {name:?}");
1502// function.name = name.to_string();
1503
1504// if object_map.contains_key("arguments") {
1505// let args = object_map.get("arguments").unwrap();
1506// let arguments = args.to_string();
1507// println!("Arguments: {arguments:?}");
1508
1509// function.arguments = arguments;
1510// }
1511
1512// let tool_call = ToolCall {
1513// id: "call_abc123".to_string(),
1514// ty: "function".to_string(),
1515// function,
1516// };
1517
1518// tool_calls.push(tool_call);
1519// }
1520// }
1521// }
1522
1523// let parsed = ParseResult {
1524// raw: input.to_owned(),
1525// content: None,
1526// tool_calls,
1527// };
1528
1529// #[cfg(feature = "logging")]
1530// info!(target: "stdout", "parsed result: {parsed:?}");
1531
1532// Ok(parsed)
1533// }
1534// Err(e) => {
1535// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1536
1537// #[cfg(feature = "logging")]
1538// error!(target: "stdout", "{}", &err_msg);
1539
1540// Err(LlamaCoreError::Operation(err_msg))
1541// }
1542// }
1543// }
1544// PromptTemplateType::Llama4Chat => {
1545// #[cfg(feature = "logging")]
1546// info!(target: "stdout", "raw input: {input:?}");
1547
1548// let mut tool_calls: Vec<ToolCall> = vec![];
1549// if let Ok(value) = serde_json::from_str::<serde_json::Value>(input) {
1550// match value.as_object() {
1551// Some(object_map) => {
1552// #[cfg(feature = "logging")]
1553// debug!(target: "stdout", "object_map: {object_map:?}");
1554
1555// // parse function name
1556// if object_map.contains_key("name") {
1557// let name = object_map.get("name").unwrap().as_str().unwrap();
1558
1559// #[cfg(feature = "logging")]
1560// debug!(target: "stdout", "name: {name:?}");
1561
1562// let mut function = Function {
1563// name: name.to_string(),
1564// arguments: String::new(),
1565// };
1566
1567// // parse function arguments
1568// if object_map.contains_key("parameters") {
1569// let args = object_map.get("parameters").unwrap();
1570// let arguments = args.to_string();
1571
1572// #[cfg(feature = "logging")]
1573// debug!(target: "stdout", "arguments: {:?}", &arguments);
1574
1575// function.arguments = arguments;
1576// }
1577
1578// tool_calls.push(ToolCall {
1579// id: "call_abc123".to_string(),
1580// ty: "function".to_string(),
1581// function,
1582// });
1583// } else {
1584// let err_msg = format!(
1585// "Failed to get the name of the function. raw input: {input:?}"
1586// );
1587
1588// #[cfg(feature = "logging")]
1589// error!(target: "stdout", "{}", &err_msg);
1590
1591// return Err(LlamaCoreError::Operation(err_msg));
1592// }
1593// }
1594// None => {
1595// let err_msg = format!("Failed to parse the JSON string. JSON: {input}");
1596
1597// #[cfg(feature = "logging")]
1598// error!(target: "stdout", "{}", &err_msg);
1599
1600// return Err(LlamaCoreError::Operation(err_msg));
1601// }
1602// }
1603// }
1604
1605// let parsed = ParseResult {
1606// raw: input.to_owned(),
1607// content: None,
1608// tool_calls,
1609// };
1610
1611// #[cfg(feature = "logging")]
1612// info!(target: "stdout", "parsed result: {parsed:?}");
1613
1614// Ok(parsed)
1615// }
1616// PromptTemplateType::Qwen3NoThink | PromptTemplateType::Smol3NoThink => {
1617// #[cfg(feature = "logging")]
1618// info!(target: "stdout", "raw input: {input:?}");
1619
1620// match regex::Regex::new(r"(?s)<tool_call>((.|\r|\n)*?)</tool_call>") {
1621// Ok(re) => {
1622// let mut values: Vec<serde_json::Value> = vec![];
1623// for cap in re.captures_iter(input) {
1624// let mut matched = cap[1].trim();
1625
1626// if matched.starts_with("\\n") {
1627// matched = matched.trim_start_matches("\\n");
1628// }
1629
1630// if matched.ends_with("\\n") {
1631// matched = matched.trim_end_matches("\\n");
1632// }
1633
1634// #[cfg(feature = "logging")]
1635// info!(target: "stdout", "captured: {matched:#?}");
1636
1637// if !matched.is_empty() {
1638// match serde_json::from_str::<serde_json::Value>(matched) {
1639// Ok(value) => values.push(value),
1640// Err(e) => {
1641// let err_msg = format!(
1642// "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
1643// );
1644
1645// #[cfg(feature = "logging")]
1646// error!(target: "stdout", "{}", &err_msg);
1647
1648// return Err(LlamaCoreError::Operation(err_msg));
1649// }
1650// }
1651// }
1652// }
1653
1654// let mut tool_calls: Vec<ToolCall> = vec![];
1655// for value in values.iter() {
1656// let name = match value.get("name") {
1657// Some(name) => name.to_string().replace("\"", ""),
1658// None => {
1659// let err_msg = format!(
1660// "Failed to get the name of the function. Tool call: {value:?}"
1661// );
1662
1663// #[cfg(feature = "logging")]
1664// error!(target: "stdout", "{}", &err_msg);
1665
1666// return Err(LlamaCoreError::Operation(err_msg));
1667// }
1668// };
1669
1670// let arguments = match value.get("arguments") {
1671// Some(arguments) => {
1672// if arguments.is_string() {
1673// arguments.as_str().unwrap().to_string()
1674// } else if arguments.is_object() {
1675// let map = arguments.as_object().unwrap();
1676
1677// #[cfg(feature = "logging")]
1678// info!(target: "stdout", "func arguments: {map:?}");
1679
1680// serde_json::to_string(map).unwrap()
1681// } else {
1682// serde_json::to_string(arguments).unwrap()
1683// }
1684// }
1685// None => {
1686// let err_msg = format!(
1687// "Failed to get the arguments of the function. Tool call: {value:?}"
1688// );
1689
1690// #[cfg(feature = "logging")]
1691// error!(target: "stdout", "{}", &err_msg);
1692
1693// return Err(LlamaCoreError::Operation(err_msg));
1694// }
1695// };
1696
1697// let function = Function { name, arguments };
1698
1699// let tool_call = ToolCall {
1700// id: "call_abc123".to_string(),
1701// ty: "function".to_string(),
1702// function,
1703// };
1704
1705// tool_calls.push(tool_call);
1706// }
1707
1708// let parsed = if tool_calls.is_empty() {
1709// ParseResult {
1710// raw: input.to_owned(),
1711// content: Some(input.to_owned()),
1712// tool_calls: vec![],
1713// }
1714// } else {
1715// ParseResult {
1716// raw: input.to_owned(),
1717// content: None,
1718// tool_calls,
1719// }
1720// };
1721
1722// #[cfg(feature = "logging")]
1723// info!(target: "stdout", "parsed result: {parsed:?}");
1724
1725// Ok(parsed)
1726// }
1727// Err(e) => {
1728// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1729
1730// #[cfg(feature = "logging")]
1731// error!(target: "stdout", "{}", &err_msg);
1732
1733// Err(LlamaCoreError::Operation(err_msg))
1734// }
1735// }
1736// }
1737// PromptTemplateType::Gemma3 => {
1738// #[cfg(feature = "logging")]
1739// info!(target: "stdout", "raw input: {input:?}");
1740
1741// match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
1742// Ok(re) => {
1743// let mut values: Vec<serde_json::Value> = vec![];
1744// for cap in re.captures_iter(input) {
1745// let mut matched = cap[1].trim();
1746
1747// if matched.starts_with("\\n") {
1748// matched = matched.trim_start_matches("\\n");
1749// }
1750
1751// if matched.ends_with("\\n") {
1752// matched = matched.trim_end_matches("\\n");
1753// }
1754
1755// #[cfg(feature = "logging")]
1756// info!(target: "stdout", "captured: {matched:#?}");
1757
1758// if !matched.is_empty() {
1759// match serde_json::from_str::<serde_json::Value>(matched) {
1760// Ok(value) => values.push(value),
1761// Err(e) => {
1762// let err_msg = format!(
1763// "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
1764// );
1765
1766// #[cfg(feature = "logging")]
1767// error!(target: "stdout", "{}", &err_msg);
1768
1769// return Err(LlamaCoreError::Operation(err_msg));
1770// }
1771// }
1772// }
1773// }
1774
1775// let mut tool_calls: Vec<ToolCall> = vec![];
1776// for value in values.iter() {
1777// let name = match value.get("name") {
1778// Some(name) => name.to_string().replace("\"", ""),
1779// None => {
1780// let err_msg = format!(
1781// "Failed to get the name of the function. Tool call: {value:?}"
1782// );
1783
1784// #[cfg(feature = "logging")]
1785// error!(target: "stdout", "{}", &err_msg);
1786
1787// return Err(LlamaCoreError::Operation(err_msg));
1788// }
1789// };
1790
1791// let arguments = match value.get("arguments") {
1792// Some(arguments) => {
1793// if arguments.is_string() {
1794// arguments.as_str().unwrap().to_string()
1795// } else if arguments.is_object() {
1796// let map = arguments.as_object().unwrap();
1797
1798// #[cfg(feature = "logging")]
1799// info!(target: "stdout", "func arguments: {map:?}");
1800
1801// serde_json::to_string(map).unwrap()
1802// } else {
1803// serde_json::to_string(arguments).unwrap()
1804// }
1805// }
1806// None => {
1807// let err_msg = format!(
1808// "Failed to get the arguments of the function. Tool call: {value:?}"
1809// );
1810
1811// #[cfg(feature = "logging")]
1812// error!(target: "stdout", "{}", &err_msg);
1813
1814// return Err(LlamaCoreError::Operation(err_msg));
1815// }
1816// };
1817
1818// let function = Function { name, arguments };
1819
1820// let tool_call = ToolCall {
1821// id: "call_abc123".to_string(),
1822// ty: "function".to_string(),
1823// function,
1824// };
1825
1826// tool_calls.push(tool_call);
1827// }
1828
1829// let parsed = if tool_calls.is_empty() {
1830// ParseResult {
1831// raw: input.to_owned(),
1832// content: Some(input.to_owned()),
1833// tool_calls: vec![],
1834// }
1835// } else {
1836// ParseResult {
1837// raw: input.to_owned(),
1838// content: None,
1839// tool_calls,
1840// }
1841// };
1842
1843// #[cfg(feature = "logging")]
1844// info!(target: "stdout", "parsed result: {parsed:?}");
1845
1846// Ok(parsed)
1847// }
1848// Err(e) => {
1849// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
1850
1851// #[cfg(feature = "logging")]
1852// error!(target: "stdout", "{}", &err_msg);
1853
1854// Err(LlamaCoreError::Operation(err_msg))
1855// }
1856// }
1857// }
1858// PromptTemplateType::GptOss => {
1859// #[cfg(feature = "logging")]
1860// info!(target: "stdout", "raw input: {input:?}");
1861
1862// // Match strings ending with: <|channel|>commentary to=functions.xxxxx <|constrain|>json<|message|>yyyyy<|call|>
1863// match regex::Regex::new(
1864// r"<\|channel\|>commentary to=functions\.([^<\s]+)\s*<\|constrain\|>json<\|message\|>([^<]*)<\|call\|>$",
1865// ) {
1866// Ok(re) => {
1867// if let Some(cap) = re.captures(input) {
1868// let function_name = cap[1].trim();
1869// let arguments = cap[2].trim();
1870
1871// #[cfg(feature = "logging")]
1872// info!(target: "stdout", "extracted function_name: {function_name}, arguments: {arguments}");
1873
1874// let function = Function {
1875// name: function_name.to_string(),
1876// arguments: arguments.to_string(),
1877// };
1878
1879// let tool_call = ToolCall {
1880// id: "call_abc123".to_string(),
1881// ty: "function".to_string(),
1882// function,
1883// };
1884
1885// let parsed = ParseResult {
1886// raw: input.to_owned(),
1887// content: None,
1888// tool_calls: vec![tool_call],
1889// };
1890
1891// #[cfg(feature = "logging")]
1892// info!(target: "stdout", "parsed result: {parsed:?}");
1893
1894// Ok(parsed)
1895// } else {
1896// match regex::Regex::new(r"(?s)```json\s*(.*?)\s*```") {
1897// Ok(re) => {
1898// let mut values: Vec<serde_json::Value> = vec![];
1899// for cap in re.captures_iter(input) {
1900// let mut matched = cap[1].trim();
1901
1902// if matched.starts_with("\\n") {
1903// matched = matched.trim_start_matches("\\n");
1904// }
1905
1906// if matched.ends_with("\\n") {
1907// matched = matched.trim_end_matches("\\n");
1908// }
1909
1910// #[cfg(feature = "logging")]
1911// info!(target: "stdout", "captured: {matched:#?}");
1912
1913// if !matched.is_empty() {
1914// match serde_json::from_str::<serde_json::Value>(matched) {
1915// Ok(value) => values.push(value),
1916// Err(e) => {
1917// let err_msg = format!(
1918// "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
1919// );
1920
1921// #[cfg(feature = "logging")]
1922// error!(target: "stdout", "{}", &err_msg);
1923
1924// return Err(LlamaCoreError::Operation(err_msg));
1925// }
1926// }
1927// }
1928// }
1929
1930// let mut tool_calls: Vec<ToolCall> = vec![];
1931// for value in values.iter() {
1932// let name = match value.get("name") {
1933// Some(name) => name.to_string().replace("\"", ""),
1934// None => {
1935// let err_msg = format!(
1936// "Failed to get the name of the function. Tool call: {value:?}"
1937// );
1938
1939// #[cfg(feature = "logging")]
1940// error!(target: "stdout", "{}", &err_msg);
1941
1942// return Err(LlamaCoreError::Operation(err_msg));
1943// }
1944// };
1945
1946// let arguments = match value.get("arguments") {
1947// Some(arguments) => {
1948// if arguments.is_string() {
1949// arguments.as_str().unwrap().to_string()
1950// } else if arguments.is_object() {
1951// let map = arguments.as_object().unwrap();
1952
1953// #[cfg(feature = "logging")]
1954// info!(target: "stdout", "func arguments: {map:?}");
1955
1956// serde_json::to_string(map).unwrap()
1957// } else {
1958// serde_json::to_string(arguments).unwrap()
1959// }
1960// }
1961// None => {
1962// let err_msg = format!(
1963// "Failed to get the arguments of the function. Tool call: {value:?}"
1964// );
1965
1966// #[cfg(feature = "logging")]
1967// error!(target: "stdout", "{}", &err_msg);
1968
1969// return Err(LlamaCoreError::Operation(err_msg));
1970// }
1971// };
1972
1973// let function = Function { name, arguments };
1974
1975// let tool_call = ToolCall {
1976// id: "call_abc123".to_string(),
1977// ty: "function".to_string(),
1978// function,
1979// };
1980
1981// tool_calls.push(tool_call);
1982// }
1983
1984// let parsed = if tool_calls.is_empty() {
1985// ParseResult {
1986// raw: input.to_owned(),
1987// content: Some(input.to_owned()),
1988// tool_calls: vec![],
1989// }
1990// } else {
1991// ParseResult {
1992// raw: input.to_owned(),
1993// content: Some(input.to_owned()),
1994// tool_calls,
1995// }
1996// };
1997
1998// #[cfg(feature = "logging")]
1999// info!(target: "stdout", "parsed result: {parsed:?}");
2000
2001// Ok(parsed)
2002// }
2003// Err(e) => {
2004// let err_msg =
2005// format!("Failed to create a regex pattern. Reason: {e}");
2006
2007// #[cfg(feature = "logging")]
2008// error!(target: "stdout", "{}", &err_msg);
2009
2010// Err(LlamaCoreError::Operation(err_msg))
2011// }
2012// }
2013// }
2014// }
2015// Err(e) => {
2016// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2017
2018// #[cfg(feature = "logging")]
2019// error!(target: "stdout", "{}", &err_msg);
2020
2021// Err(LlamaCoreError::Operation(err_msg))
2022// }
2023// }
2024// }
2025// PromptTemplateType::Qwen3Agent => {
2026// #[cfg(feature = "logging")]
2027// info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2028
2029// // detect <action> tags
2030// match regex::Regex::new(r"<action>(.*?)</action>")
2031// .unwrap()
2032// .captures(input)
2033// {
2034// Some(captures) => {
2035// let action = captures.get(1).unwrap().as_str();
2036
2037// #[cfg(feature = "logging")]
2038// info!(target: "stdout", "Action: {action}");
2039
2040// match serde_json::from_str::<serde_json::Value>(action) {
2041// Ok(value) => {
2042// let name = match value.get("name") {
2043// Some(name) => name.to_string().replace("\"", ""),
2044// None => {
2045// let err_msg = format!(
2046// "Failed to get the name of the function. Tool call: {value:?}"
2047// );
2048
2049// #[cfg(feature = "logging")]
2050// error!(target: "stdout", "{}", &err_msg);
2051
2052// return Err(LlamaCoreError::Operation(err_msg));
2053// }
2054// };
2055
2056// let arguments = match value.get("arguments") {
2057// Some(arguments) => {
2058// if arguments.is_string() {
2059// arguments.as_str().unwrap().to_string()
2060// } else if arguments.is_object() {
2061// let map = arguments.as_object().unwrap();
2062
2063// #[cfg(feature = "logging")]
2064// info!(target: "stdout", "func arguments: {map:?}");
2065
2066// serde_json::to_string(map).unwrap()
2067// } else {
2068// serde_json::to_string(arguments).unwrap()
2069// }
2070// }
2071// None => {
2072// let err_msg = format!(
2073// "Failed to get the arguments of the function. Tool call: {value:?}"
2074// );
2075
2076// #[cfg(feature = "logging")]
2077// error!(target: "stdout", "{}", &err_msg);
2078
2079// return Err(LlamaCoreError::Operation(err_msg));
2080// }
2081// };
2082
2083// let function = Function { name, arguments };
2084
2085// let tool_call = ToolCall {
2086// id: "call_abc123".to_string(),
2087// ty: "function".to_string(),
2088// function,
2089// };
2090
2091// Ok(ParseResult {
2092// raw: input.to_owned(),
2093// content: Some(input.to_owned()),
2094// tool_calls: vec![tool_call],
2095// })
2096// }
2097// Err(e) => {
2098// let err_msg = format!(
2099// "Failed to deserialize generated tool calls: {action:#?}. Reason: {e}"
2100// );
2101
2102// #[cfg(feature = "logging")]
2103// error!(target: "stdout", "{}", &err_msg);
2104
2105// Err(LlamaCoreError::Operation(err_msg))
2106// }
2107// }
2108// }
2109// None => match input.contains("<final_answer>") {
2110// true => Ok(ParseResult {
2111// raw: input.to_owned(),
2112// content: Some(input.to_owned()),
2113// tool_calls: vec![],
2114// }),
2115// false => {
2116// let content = format!("<final_answer>{}</final_answer>", input.trim());
2117
2118// Ok(ParseResult {
2119// raw: input.to_owned(),
2120// content: Some(content),
2121// tool_calls: vec![],
2122// })
2123// }
2124// },
2125// }
2126// }
2127// PromptTemplateType::SeedOssNoThink | PromptTemplateType::SeedOssThink => {
2128// #[cfg(feature = "logging")]
2129// info!(target: "stdout", "Raw input to tool call parser: {input:?}");
2130
2131// match regex::Regex::new(r"```json\n([\s\S]*?)\n") {
2132// Ok(re) => {
2133// let mut values: Vec<serde_json::Value> = vec![];
2134// for cap in re.captures_iter(input) {
2135// let mut matched = cap[1].trim();
2136
2137// if matched.starts_with("\\n") {
2138// matched = matched.trim_start_matches("\\n");
2139// }
2140
2141// if matched.ends_with("\\n") {
2142// matched = matched.trim_end_matches("\\n");
2143// }
2144
2145// #[cfg(feature = "logging")]
2146// info!(target: "stdout", "captured: {matched:#?}");
2147
2148// if !matched.is_empty() {
2149// match serde_json::from_str::<serde_json::Value>(matched) {
2150// Ok(value) => values.push(value),
2151// Err(e) => {
2152// let err_msg = format!(
2153// "Failed to deserialize generated tool calls: {matched:#?}. Reason: {e}"
2154// );
2155
2156// #[cfg(feature = "logging")]
2157// error!(target: "stdout", "{}", &err_msg);
2158
2159// return Err(LlamaCoreError::Operation(err_msg));
2160// }
2161// }
2162// }
2163// }
2164
2165// let mut tool_calls: Vec<ToolCall> = vec![];
2166// for value in values.iter() {
2167// let name = match value.get("name") {
2168// Some(name) => name.to_string().replace("\"", ""),
2169// None => {
2170// let err_msg = format!(
2171// "Failed to get the name of the function. Tool call: {value:?}"
2172// );
2173
2174// #[cfg(feature = "logging")]
2175// error!(target: "stdout", "{}", &err_msg);
2176
2177// return Err(LlamaCoreError::Operation(err_msg));
2178// }
2179// };
2180
2181// let arguments = match value.get("arguments") {
2182// Some(arguments) => {
2183// if arguments.is_string() {
2184// arguments.as_str().unwrap().to_string()
2185// } else if arguments.is_object() {
2186// let map = arguments.as_object().unwrap();
2187
2188// #[cfg(feature = "logging")]
2189// info!(target: "stdout", "func arguments: {map:?}");
2190
2191// serde_json::to_string(map).unwrap()
2192// } else {
2193// serde_json::to_string(arguments).unwrap()
2194// }
2195// }
2196// None => {
2197// let err_msg = format!(
2198// "Failed to get the arguments of the function. Tool call: {value:?}"
2199// );
2200
2201// #[cfg(feature = "logging")]
2202// error!(target: "stdout", "{}", &err_msg);
2203
2204// return Err(LlamaCoreError::Operation(err_msg));
2205// }
2206// };
2207
2208// let function = Function { name, arguments };
2209
2210// let tool_call = ToolCall {
2211// id: "call_abc123".to_string(),
2212// ty: "function".to_string(),
2213// function,
2214// };
2215
2216// tool_calls.push(tool_call);
2217// }
2218
2219// let parsed = if tool_calls.is_empty() {
2220// ParseResult {
2221// raw: input.to_owned(),
2222// content: Some(input.to_owned()),
2223// tool_calls: vec![],
2224// }
2225// } else {
2226// ParseResult {
2227// raw: input.to_owned(),
2228// content: None,
2229// tool_calls,
2230// }
2231// };
2232
2233// #[cfg(feature = "logging")]
2234// info!(target: "stdout", "parsed result: {parsed:?}");
2235
2236// Ok(parsed)
2237// }
2238// Err(e) => {
2239// let err_msg = format!("Failed to create a regex pattern. Reason: {e}");
2240
2241// #[cfg(feature = "logging")]
2242// error!(target: "stdout", "{}", &err_msg);
2243
2244// Err(LlamaCoreError::Operation(err_msg))
2245// }
2246// }
2247// }
2248// _ => {
2249// let err_msg = format!(
2250// "The tool use is only supported for prompt templates: {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, and {}.",
2251// PromptTemplateType::MistralTool,
2252// PromptTemplateType::ChatMLTool,
2253// PromptTemplateType::GroqLlama3Tool,
2254// PromptTemplateType::Llama3Tool,
2255// PromptTemplateType::InternLM2Tool,
2256// PromptTemplateType::NemotronTool,
2257// PromptTemplateType::FunctionaryV32,
2258// PromptTemplateType::MistralSmallTool,
2259// PromptTemplateType::Llama4Chat,
2260// PromptTemplateType::Qwen3NoThink,
2261// PromptTemplateType::Smol3NoThink,
2262// PromptTemplateType::Gemma3,
2263// PromptTemplateType::GptOss,
2264// PromptTemplateType::Qwen3Agent,
2265// PromptTemplateType::SeedOssNoThink,
2266// PromptTemplateType::SeedOssThink
2267// );
2268
2269// #[cfg(feature = "logging")]
2270// error!(target: "stdout", "{}", &err_msg);
2271
2272// Err(LlamaCoreError::Operation(err_msg))
2273// }
2274// }
2275// }
2276
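/// Apply per-request overrides (temperature, top_p) to the model metadata, disable
/// the `embeddings` flag if it is set, push the updated metadata to the target graph
/// when anything changed, and return a copy of the effective metadata.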
2277fn check_model_metadata(
2278 chat_request: &RequestOfModelResponse,
2279) -> Result<GgmlMetadata, LlamaCoreError> {
2280 let mut should_update = false;
2281
2282 let mut metadata = get_model_metadata(chat_request.model.as_ref())?;
2283
2284 // TODO: check if necessary to update `image`
2285 // if metadata.prompt_template.is_image_supported() {
2286 // if let Some(ChatCompletionRequestMessage::User(user_message)) = chat_request.messages.last()
2287 // {
2288 // if let ChatCompletionUserMessageContent::Parts(parts) = user_message.content() {
2289 // for part in parts {
2290 // if let ContentPart::Image(image_part) = part {
2291 // let image = image_part.image();
2292
2293 // if image.is_url() {
2294 // let err_msg = "The image is provided in URL format. Only base64 format is supported.".to_string();
2295
2296 // #[cfg(feature = "logging")]
2297 // error!(target: "stdout", "{}", &err_msg);
2298
2299 // return Err(LlamaCoreError::Operation(err_msg));
2300 // } else {
2301 // #[cfg(feature = "logging")]
2302 // info!(target: "stdout", "The image is provided in base64 format.");
2303
2304 // // TODO: now only support a single image
2305
2306 // break;
2307 // }
2308 // }
2309 // }
2310 // }
2311 // }
2312 // }
2313
2314 // check if necessary to update temperature
2315 if let Some(temp) = chat_request.temperature {
2316 if metadata.temperature != temp {
2317 // update temperature
2318 metadata.temperature = temp;
2319
2320 if !should_update {
2321 should_update = true;
2322 }
2323 }
2324 }
2325
2326 // check if necessary to update top_p
2327 if let Some(top_p) = chat_request.top_p {
2328 if metadata.top_p != top_p {
2329 // update top_p
2330 metadata.top_p = top_p;
2331
2332 if !should_update {
2333 should_update = true;
2334 }
2335 }
2336 }
2337
2338 // check if the `embedding` option is disabled
2339 if metadata.embeddings {
2340 metadata.embeddings = false;
2341
2342 if !should_update {
2343 should_update = true;
2344 }
2345 }
2346
2347 if should_update {
2348 #[cfg(feature = "logging")]
2349 info!(target: "stdout", "Update the model metadata.");
2350
2351 // update the target graph with the new metadata
2352 update_model_metadata(chat_request.model.as_ref(), &metadata)?;
2353 }
2354
2355 Ok(metadata)
2356}
2357
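/// Adjust `n_predict` for this request from `chat_request.max_output_tokens` and the
/// remaining context budget (`available_completion_tokens`), and push the updated
/// metadata to the target graph when it changed.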
2358fn update_n_predict(
2359 chat_request: &RequestOfModelResponse,
2360 metadata: &mut GgmlMetadata,
2361 available_completion_tokens: u64,
2362) -> Result<(), LlamaCoreError> {
2363 let mut should_update = false;
2364
2365 #[cfg(feature = "logging")]
2366 info!(target: "stdout", "n_predict: {}", metadata.n_predict);
2367
    // Priority, from high to low:
    // 1. chat_request.max_output_tokens
    // 2. available_completion_tokens
    // 3. metadata.n_predict
2372
2373 if let Some(max_output_tokens) = chat_request.max_output_tokens {
2374 if metadata.n_predict != max_output_tokens {
2375 #[cfg(feature = "logging")]
2376 info!(target: "stdout", "Update n_predict with max_output_tokens from {} to {}", metadata.n_predict, max_output_tokens);
2377
2378 metadata.n_predict = max_output_tokens;
2379
2380 if !should_update {
2381 should_update = true;
2382 }
2383 }
2384 }
2385
2386 // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2387 if metadata.n_predict == -2 {
2388 #[cfg(feature = "logging")]
2389 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2390
2391 // update n_predict
2392 metadata.n_predict = available_completion_tokens as i32;
2393
2394 if !should_update {
2395 should_update = true;
2396 }
2397 }
2398
2399 if metadata.n_predict == -1
2400 || (metadata.n_predict > 0 && metadata.n_predict < available_completion_tokens as i32)
2401 || (metadata.n_predict < 0 && metadata.n_predict != -2)
2402 // TODO: remove this condition after [Issue #3958 on WasmEdge](https://github.com/WasmEdge/WasmEdge/issues/3958) is fixed
2403 {
2404 #[cfg(feature = "logging")]
2405 info!(target: "stdout", "Update n_predict with available_completion_tokens from {} to {}", metadata.n_predict, available_completion_tokens);
2406
2407 // update n_predict
2408 metadata.n_predict = available_completion_tokens as i32;
2409
2410 if !should_update {
2411 should_update = true;
2412 }
2413 }
2414
2415 if should_update {
2416 #[cfg(feature = "logging")]
2417 info!(target: "stdout", "Update the model metadata.");
2418
2419 // update the target graph with the new metadata
2420 update_model_metadata(chat_request.model.as_ref(), metadata)?;
2421 }
2422
2423 Ok(())
2424}
2425
/// Post-process the generated output according to the prompt template type.
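///
/// A minimal sketch of the intended behavior (the input string and template type
/// below are illustrative, not taken from a real model run):
///
/// ```ignore
/// // a ChatML-style reply ends with "<|im_end|>", which post-processing strips
/// let cleaned = post_process("Hello!<|im_end|>", &PromptTemplateType::ChatML).unwrap();
/// assert_eq!(cleaned, "Hello!");
/// ```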
2427fn post_process(
2428 output: impl AsRef<str>,
2429 template_ty: &PromptTemplateType,
2430) -> Result<String, String> {
2431 let output = if *template_ty == PromptTemplateType::Baichuan2 {
2432 if output.as_ref().contains("用户:") {
2433 output.as_ref().trim_end_matches("用户:").trim().to_owned()
2434 } else {
2435 output.as_ref().trim().to_owned()
2436 }
2437 } else if *template_ty == PromptTemplateType::OpenChat {
2438 if output.as_ref().contains("<|end_of_turn|>") {
2439 output
2440 .as_ref()
2441 .trim_end_matches("<|end_of_turn|>")
2442 .trim()
2443 .to_owned()
2444 } else {
2445 output.as_ref().trim().to_owned()
2446 }
2447 } else if *template_ty == PromptTemplateType::GemmaInstruct
2448 || *template_ty == PromptTemplateType::Gemma3
2449 {
2450 let s = output.as_ref().trim();
2451 if s.ends_with("<end_of_turn>") {
2452 s.trim_end_matches("<end_of_turn>").trim().to_owned()
2453 } else {
2454 s.to_owned()
2455 }
2456 } else if *template_ty == PromptTemplateType::ChatML
2457 || *template_ty == PromptTemplateType::ChatMLTool
2458 || *template_ty == PromptTemplateType::InternLM2Tool
2459 || *template_ty == PromptTemplateType::MiniCPMV
2460 {
2461 let mut s = output.as_ref().trim();
2462 if s.ends_with("<|endoftext|>") {
2463 s = s.trim_end_matches("<|endoftext|>").trim();
2464 }
2465
2466 if s.starts_with(":") {
2467 s = s.trim_start_matches(":").trim();
2468 }
2469
2470 // handle Qwen3 empty think tags
2471 let x = {
2472 let pat = r#"<think>
2473
2474</think>
2475"#;
2476 if s.contains(pat) {
2477 let x = s.replace(pat, "");
2478 if x.starts_with("()") {
2479 x.trim_start_matches("()").to_owned()
2480 } else {
2481 x.to_owned()
2482 }
2483 } else {
2484 s.to_owned()
2485 }
2486 };
2487 s = x.trim();
2488
2489 if s.contains("<|im_start|>") && s.contains("<|im_end|>") {
2490 let idx_start = s.find("<|im_start|>").unwrap();
2491 let idx_end = s.find("<|im_end|>").unwrap();
2492
2493 match idx_start <= idx_end {
2494 true => s.split("<|im_start|>").collect::<Vec<_>>()[0]
2495 .trim()
2496 .to_owned(),
2497 false => s.split("<|im_end|>").collect::<Vec<_>>()[0]
2498 .trim()
2499 .to_owned(),
2500 }
2501 } else if s.contains("<|im_start|>") {
2502 s.split("<|im_start|>").collect::<Vec<_>>()[0]
2503 .trim()
2504 .to_owned()
2505 } else if s.contains("<|im_end|>") {
2506 let output = s.trim_end_matches("<|im_end|>").trim();
2507 if output.starts_with(": ") {
2508 output.trim_start_matches(": ").to_owned()
2509 } else {
2510 output.to_owned()
2511 }
2512 } else {
2513 s.to_owned()
2514 }
2515 } else if *template_ty == PromptTemplateType::Zephyr
2516 || *template_ty == PromptTemplateType::MistralLite
2517 || *template_ty == PromptTemplateType::MistralTool
2518 || *template_ty == PromptTemplateType::MistralInstruct
2519 || *template_ty == PromptTemplateType::MistralSmallChat
2520 || *template_ty == PromptTemplateType::MistralSmallTool
2521 || *template_ty == PromptTemplateType::BreezeInstruct
2522 {
2523 if output.as_ref().contains("</s><") {
2524 output.as_ref().trim_end_matches("</s><").trim().to_owned()
2525 } else if output.as_ref().contains("</s>") {
2526 output
2527 .as_ref()
2528 .strip_suffix("</s>")
2529 .unwrap()
2530 .trim()
2531 .to_owned()
2532 } else {
2533 output.as_ref().trim().to_owned()
2534 }
2535 } else if *template_ty == PromptTemplateType::DeepseekChat {
2536 if output.as_ref().contains("<|end_of_sentence|>") {
2537 output
2538 .as_ref()
2539 .trim_end_matches("<|end_of_sentence|>")
2540 .trim()
2541 .replace("<|end_of_sentence|>", " ")
2542 .trim()
2543 .to_owned()
2544 } else {
2545 output.as_ref().trim().to_owned()
2546 }
2547 } else if *template_ty == PromptTemplateType::HumanAssistant {
2548 if output.as_ref().contains("Human:") {
2549 output.as_ref().trim_end_matches("Human:").trim().to_owned()
2550 } else {
2551 output.as_ref().trim().to_owned()
2552 }
2553 } else if *template_ty == PromptTemplateType::SolarInstruct {
2554 let s = output.as_ref().trim();
2555
2556 if s.starts_with("### Answer") {
2557 let s = s.trim_start_matches("###").trim();
2558
2559 if s.starts_with("Answer:\n") {
2560 s.replace("Answer:\n", "Answer: ")
2561 } else {
2562 s.to_owned()
2563 }
2564 } else {
2565 s.to_owned()
2566 }
2567 } else if *template_ty == PromptTemplateType::Llama2Chat
2568 || *template_ty == PromptTemplateType::NemotronTool
2569 || *template_ty == PromptTemplateType::NemotronChat
2570 {
2571 let s = output.as_ref().trim();
2572 if s.ends_with("</s>") {
2573 s.trim_end_matches("</s>").trim().to_owned()
2574 } else {
2575 s.to_owned()
2576 }
2577 } else if *template_ty == PromptTemplateType::Llama3Chat
2578 || *template_ty == PromptTemplateType::GroqLlama3Tool
2579 || *template_ty == PromptTemplateType::Llama3Tool
2580 || *template_ty == PromptTemplateType::FunctionaryV32
2581 {
2582 let s = output.as_ref().trim();
2583 if s.ends_with("<|eot_id|>") {
2584 s.trim_end_matches("<|eot_id|>").trim().to_owned()
2585 } else {
2586 s.to_owned()
2587 }
2588 } else if *template_ty == PromptTemplateType::Phi3Chat {
2589 let s = output.as_ref().trim();
2590 if s.ends_with("<|end|>") {
2591 s.trim_end_matches("<|end|>").trim().to_owned()
2592 } else {
2593 s.to_owned()
2594 }
2595 } else if *template_ty == PromptTemplateType::Phi4Chat {
2596 let mut s = output.as_ref().trim();
2597
2598 if s.starts_with("think>") {
2599 s = s.trim_start_matches("think>").trim();
2600 }
2601
2602 if s.ends_with("<|im_end|>") {
2603 s.trim_end_matches("<|im_end|>").trim().to_owned()
2604 } else if s.ends_with("<|end|>") {
2605 s.trim_end_matches("<|end|>").trim().to_owned()
2606 } else {
2607 s.to_owned()
2608 }
2609 } else if *template_ty == PromptTemplateType::FunctionaryV31 {
2610 let mut s = output.as_ref().trim();
2611 if s.ends_with("<|eot_id|>") {
2612 s = s.trim_end_matches("<|eot_id|>").trim();
2613 }
2614 if s.ends_with("<|eom_id|>") {
2615 s = s.trim_end_matches("<|eom_id|>").trim();
2616 }
2617 s.to_owned()
2618 } else if *template_ty == PromptTemplateType::MoxinChat
2619 || *template_ty == PromptTemplateType::MoxinInstruct
2620 {
2621 let s = output.as_ref().trim();
2622 if s.ends_with("</s>") {
2623 s.trim_end_matches("</s>").trim().to_owned()
2624 } else if s.ends_with("[INST]") {
2625 s.trim_end_matches("[INST]").trim().to_owned()
2626 } else {
2627 s.to_owned()
2628 }
2629 } else if *template_ty == PromptTemplateType::Falcon3 {
2630 let s = output.as_ref().trim();
2631 if s.ends_with("<|endoftext|>") {
2632 s.trim_end_matches("<|endoftext|>").trim().to_owned()
2633 } else {
2634 s.to_owned()
2635 }
2636 } else if *template_ty == PromptTemplateType::Megrez {
2637 let s = output.as_ref().trim();
2638 if s.ends_with("<|turn_end|>") {
2639 s.trim_end_matches("<|turn_end|>").trim().to_owned()
2640 } else {
2641 s.to_owned()
2642 }
2643 } else if *template_ty == PromptTemplateType::Qwen2vl
2644 || *template_ty == PromptTemplateType::Qwen3NoThink
2645 || *template_ty == PromptTemplateType::ChatMLThink
2646 {
2647 let mut s = output.as_ref().trim();
2648
2649 if s.starts_with(":") {
2650 s = s.trim_start_matches(":").trim();
2651 }
2652
2653 if s.starts_with("</think>") {
2654 s = s.trim_start_matches("</think>").trim();
2655 }
2656
2657 if s.ends_with("<|im_end|>") {
2658 s.trim_end_matches("<|im_end|>").trim().to_owned()
2659 } else {
2660 s.to_owned()
2661 }
2662 } else if *template_ty == PromptTemplateType::VicunaLlava {
2663 let s = output.as_ref().trim();
2664 if s.ends_with("</s>") {
2665 s.trim_end_matches("</s>").trim().to_owned()
2666 } else {
2667 s.to_owned()
2668 }
2669 } else if *template_ty == PromptTemplateType::ExaoneDeepChat
2670 || *template_ty == PromptTemplateType::ExaoneChat
2671 {
2672 let mut s = output.as_ref().trim();
2673
2674 if s.ends_with("[|endofturn|]") {
2675 s = s.trim_end_matches("[|endofturn|]").trim();
2676 }
2677
2678 s.to_owned()
2679 } else if *template_ty == PromptTemplateType::Llama4Chat {
2680 let mut s = output.as_ref().trim();
2681
2682 if s.ends_with("<|eot|>") {
2683 s = s.trim_end_matches("<|eot|>").trim();
2684 }
2685
2686 s.to_owned()
2687 } else if *template_ty == PromptTemplateType::Smolvl {
2688 let mut s = output.as_ref().trim();
2689
2690 if s.starts_with(":") {
2691 s = s.trim_start_matches(":").trim();
2692 }
2693
2694 if s.ends_with("<end_of_utterance>") {
2695 s = s.trim_end_matches("<end_of_utterance>").trim();
2696 }
2697
2698 if s.contains("<end_of_utterance>:") {
2699 let parts = s.split("<end_of_utterance>:").collect::<Vec<_>>();
2700 parts.last().unwrap().trim().to_owned()
2701 } else {
2702 s.to_owned()
2703 }
2704 } else if *template_ty == PromptTemplateType::Smol3NoThink {
2705 let mut s = output.as_ref().trim();
2706
2707 if s.ends_with("<|im_end|>") {
2708 s = s.trim_end_matches("<|im_end|>").trim();
2709 }
2710
2711 let re = regex::Regex::new(r"(?s)^<think>.*?</think>\s*").unwrap();
2712 re.replace(s, "").to_string()
2713 } else if *template_ty == PromptTemplateType::GptOss {
2714 let s = output.as_ref().trim();
2715
2716 let re =
2717 regex::Regex::new(r"(?s).*<\|channel\|>final<\|message\|>(.*?)<\|return\|>$").unwrap();
2718
2719 if let Some(caps) = re.captures(s) {
2720 let extracted = &caps[1];
2721 extracted.to_owned()
2722 } else {
2723 s.to_owned()
2724 }
2725 } else if *template_ty == PromptTemplateType::Qwen3Agent {
2726 let mut s = output.as_ref().trim();
2727
2728 if s.starts_with(":") {
2729 s = s.trim_start_matches(":").trim();
2730 }
2731
2732 if s.starts_with("</think>") {
2733 s = s.trim_start_matches("</think>").trim();
2734 }
2735
2736 if s.ends_with("<|im_end|>") {
2737 s = s.trim_end_matches("<|im_end|>").trim();
2738 }
2739
2740 if s.contains("<final_answer>") && !s.contains("</final_answer>") {
2741 format!("{s}</final_answer>")
2742 } else {
2743 s.to_owned()
2744 }
2745 } else if *template_ty == PromptTemplateType::SeedOssNoThink {
2746 let s = output.as_ref().trim();
2747
2748 let re = regex::Regex::new(r"(?s)</seed:think>\s*(.*?)\s*<seed:eos>").unwrap();
2749
2750 if let Some(caps) = re.captures(s) {
2751 let extracted = &caps[1];
2752 extracted.to_owned()
2753 } else {
2754 s.to_owned()
2755 }
2756 } else {
2757 output.as_ref().trim().to_owned()
2758 };
2759
2760 Ok(output)
2761}
2762
2763/// Build the chat prompt from the chat messages.
2764///
2765/// # Arguments
2766///
2767/// * `model_name`: The name of the model.
2768///
2769/// * `chat_request`: The chat request.
2770///
2771/// # Returns
2772///
2773/// A tuple containing the prompt, the number of available tokens for completions, and a boolean indicating whether tools are used.
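///
/// If the prompt exceeds roughly 80% of the context window, older messages are pruned
/// from the front of the history (keeping the system message, if any, and always
/// resuming from a user message) and the prompt is rebuilt; if nothing more can be
/// pruned, an error is returned.
///
/// A hypothetical call site (the `request` binding is illustrative):
///
/// ```ignore
/// let (prompt, available_completion_tokens, tool_use) =
///     build_prompt(request.model.as_ref(), &mut request)?;
/// ```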
2774fn build_prompt(
2775 model_name: Option<&String>,
2776 chat_request: &mut RequestOfModelResponse,
2777) -> Result<(String, u64, bool), LlamaCoreError> {
2778 let metadata = get_model_metadata(model_name)?;
2779 let ctx_size = metadata.ctx_size as u64;
2780 let chat_prompt = ChatPrompt::from(metadata.prompt_template);
2781
2782 // compute max prompt tokens, which is 80% of the context size
2783 let max_prompt_tokens = ctx_size * 4 / 5;
2784
    // the `input` field must be present before it can be converted; checking first
    // avoids panicking on `unwrap` when the field is missing
    let input = match chat_request.input.as_ref() {
        Some(input) => input,
        None => {
            let err_msg = "The `input` field of the request is empty.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
        }
    };

    // convert `chat_request.input` into chat completion messages
    let mut chat_completions_messages = to_chat_messages(input)?;

    #[cfg(feature = "logging")]
    debug!(target: "stdout", "converted chat messages: {chat_completions_messages:?}");

    loop {
2801
2802 let (prompt, tool_use) = match &chat_request.tool_choice {
2803 ToolChoice::None => {
2804 match chat_prompt.build_with_tools(&mut chat_completions_messages, Some(&[])) {
2805 Ok(prompt) => (prompt, false),
2806 Err(e) => {
2807 let err_msg = format!("Fail to build chat prompts. Reason: {e}");
2808
2809 #[cfg(feature = "logging")]
2810 error!(target: "stdout", "{}", &err_msg);
2811
2812 return Err(LlamaCoreError::Operation(err_msg));
2813 }
2814 }
2815 }
2816 _ => todo!("Other tool choices are not supported yet."),
2817 };
2818 #[cfg(feature = "logging")]
2819 info!(target: "stdout", "Try to set prompt: {prompt}");
2820
2821 // set prompt
2822 set_prompt(model_name, &prompt)?;
2823
2824 // Retrieve the number of prompt tokens.
2825 let token_info = get_token_info_by_graph_name(model_name)?;
2826
2827 match token_info.prompt_tokens > max_prompt_tokens {
2828 true => {
2829 match chat_completions_messages[0].role() {
2830 ChatCompletionRole::System => {
2831 // corner case: context size is too small, `system -> user -> assistant -> tool` cannot be trimmed.
2832 if chat_completions_messages.len() == 4
2833 && chat_completions_messages[1].role() == ChatCompletionRole::User
2834 && chat_completions_messages[2].role() == ChatCompletionRole::Assistant
2835 && chat_completions_messages[3].role() == ChatCompletionRole::Tool
2836 {
2837 let err_msg = format!(
2838 "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
2839 token_info.prompt_tokens, max_prompt_tokens
2840 );
2841
2842 #[cfg(feature = "logging")]
2843 error!(target: "stdout", "{}", &err_msg);
2844
2845 return Err(LlamaCoreError::Operation(err_msg));
2846 }
2847
2848 if chat_completions_messages.len() > 2 {
2849 #[cfg(feature = "logging")]
2850 info!(target: "stdout", "Prune chat history: current length {}", chat_completions_messages.len());
2851
2852 // remove user_1 if it exists
2853 // For example, `system -> user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> ... -> user_2 -> ... -> user_latest`
2854 if chat_completions_messages[1].role() == ChatCompletionRole::User {
2855 let user_message = chat_completions_messages.remove(1);
2856
2857 #[cfg(feature = "logging")]
2858 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2859 }
2860
2861 // remove all messages until the message is of `user`
2862 // For example, `system -> ... -> user_2 -> ... -> user_latest` will be converted to `system -> user_2 -> ... -> user_latest`
2863 while chat_completions_messages[1].role() != ChatCompletionRole::User {
2864 let message = chat_completions_messages.remove(1);
2865
2866 #[cfg(feature = "logging")]
2867 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2868
2869 if chat_completions_messages.len() == 1 {
2870 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2871
2872 #[cfg(feature = "logging")]
2873 error!(target: "stdout", "{err_msg}");
2874
2875 return Err(LlamaCoreError::Operation(err_msg));
2876 }
2877 }
2878 } else if token_info.prompt_tokens > ctx_size {
2879 let err_msg = format!(
2880 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2881 token_info.prompt_tokens, ctx_size
2882 );
2883
2884 #[cfg(feature = "logging")]
2885 error!(target: "stdout", "{}", &err_msg);
2886
2887 return Err(LlamaCoreError::Operation(err_msg));
2888 } else {
2889 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2890 }
2891 }
2892 ChatCompletionRole::User => {
2893 // corner case: context size is too small, `user -> assistant -> tool` cannot be trimmed.
                    // for a 3-message history the pattern lives at indices 0..=2; indexing [3] would panic
                    if chat_completions_messages.len() == 3
                        && chat_completions_messages[0].role() == ChatCompletionRole::User
                        && chat_completions_messages[1].role() == ChatCompletionRole::Assistant
                        && chat_completions_messages[2].role() == ChatCompletionRole::Tool
2898 {
2899 let err_msg = format!(
2900 "The number of prompt tokens ({}) is greater than the max prompt tokens ({}). Please increase the context size.",
2901 token_info.prompt_tokens, max_prompt_tokens
2902 );
2903
2904 #[cfg(feature = "logging")]
2905 error!(target: "stdout", "{}", &err_msg);
2906
2907 return Err(LlamaCoreError::Operation(err_msg));
2908 }
2909
2910 if chat_completions_messages.len() > 1 {
2911 // user_1 -> ... -> user_2 -> ... -> user_latest
2912
2913 // remove user_1 if it exists
2914 // For example, `user_1 -> ... -> user_2 -> ... -> user_latest` will be converted to `... -> user_2 -> ... -> user_latest`
2915 if chat_completions_messages[0].role() == ChatCompletionRole::User {
2916 let user_message = chat_completions_messages.remove(0);
2917
2918 #[cfg(feature = "logging")]
2919 info!(target: "stdout", "Remove a user message from the chat history: {user_message:?}");
2920 }
2921
2922 // remove all messages until the message is of `user`
2923 // For example, `... -> user_2 -> ... -> user_latest` will be converted to `user_2 -> ... -> user_latest`
2924 while chat_completions_messages[0].role() != ChatCompletionRole::User {
2925 let message = chat_completions_messages.remove(0);
2926
2927 #[cfg(feature = "logging")]
2928 info!(target: "stdout", "Remove a {} message from the chat history: {:?}", message.role(), message);
2929
2930 if chat_completions_messages.is_empty() {
2931 let err_msg = format!("The last message in the chat history should be a user message, but found a {} message.", message.role());
2932
2933 #[cfg(feature = "logging")]
2934 error!(target: "stdout", "{err_msg}");
2935
2936 return Err(LlamaCoreError::Operation(err_msg));
2937 }
2938 }
2939 } else if token_info.prompt_tokens > ctx_size {
2940 let err_msg = format!(
2941 "The number of prompt tokens ({}) is greater than the context size ({}). Please increase the context size, or simplify the input message.",
2942 token_info.prompt_tokens, ctx_size
2943 );
2944
2945 #[cfg(feature = "logging")]
2946 error!(target: "stdout", "{}", &err_msg);
2947
2948 return Err(LlamaCoreError::Operation(err_msg));
2949 } else {
2950 return Ok((prompt, ctx_size - token_info.prompt_tokens, tool_use));
2951 }
2952 }
2953 _ => {
2954 #[cfg(feature = "logging")]
2955 info!(target: "stdout", "remove a {} message from the message queue", chat_completions_messages[0].role());
2956
2957 chat_completions_messages.remove(0);
2958 }
2959 }
2960
2961 continue;
2962 }
2963 false => return Ok((prompt, ctx_size - max_prompt_tokens, tool_use)),
2964 }
2965 }
2966}
2967
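/// Write the prompt into input tensor 0 of the chat graph selected by `model_name`,
/// falling back to the first available graph when the name is missing or unknown.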
2968fn set_prompt(model_name: Option<&String>, prompt: impl AsRef<str>) -> Result<(), LlamaCoreError> {
2969 let chat_graphs = match CHAT_GRAPHS.get() {
2970 Some(chat_graphs) => chat_graphs,
2971 None => {
2972 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
2973
2974 #[cfg(feature = "logging")]
2975 error!(target: "stdout", "{}", &err_msg);
2976
2977 return Err(LlamaCoreError::Operation(err_msg.into()));
2978 }
2979 };
2980
2981 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
2982 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
2983
2984 #[cfg(feature = "logging")]
2985 error!(target: "stdout", "{}", &err_msg);
2986
2987 LlamaCoreError::Operation(err_msg)
2988 })?;
2989
2990 match model_name {
2991 Some(model_name) => {
2992 #[cfg(feature = "logging")]
2993 info!(target: "stdout", "Set prompt to the chat model named {model_name}");
2994
2995 match chat_graphs.contains_key(model_name) {
2996 true => {
2997 let graph = chat_graphs.get_mut(model_name).unwrap();
2998 let tensor_data = prompt.as_ref().as_bytes().to_vec();
2999 set_tensor_data_u8(graph, 0, &tensor_data)
3000 }
3001 false => match chat_graphs.iter_mut().next() {
3002 Some((_, graph)) => {
3003 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3004 set_tensor_data_u8(graph, 0, &tensor_data)
3005 }
3006 None => {
3007 let err_msg = "There is no model available in the chat graphs.";
3008
3009 #[cfg(feature = "logging")]
3010 error!(target: "stdout", "{}", &err_msg);
3011
3012 Err(LlamaCoreError::Operation(err_msg.into()))
3013 }
3014 },
3015 }
3016 }
3017 None => {
3018 #[cfg(feature = "logging")]
3019 info!(target: "stdout", "Set prompt to the default chat model.");
3020
3021 match chat_graphs.iter_mut().next() {
3022 Some((_, graph)) => {
3023 let tensor_data = prompt.as_ref().as_bytes().to_vec();
3024 set_tensor_data_u8(graph, 0, &tensor_data)
3025 }
3026 None => {
3027 let err_msg = "There is no model available in the chat graphs while trying to set prompt to the default model.";
3028
3029 #[cfg(feature = "logging")]
3030 error!(target: "stdout", "{err_msg}");
3031
3032 Err(LlamaCoreError::Operation(err_msg.into()))
3033 }
3034 }
3035 }
3036 }
3037}
3038
3039/// Get a copy of the metadata of the model.
3040fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlMetadata, LlamaCoreError> {
3041 let chat_graphs = match CHAT_GRAPHS.get() {
3042 Some(chat_graphs) => chat_graphs,
3043 None => {
3044 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3045
3046 #[cfg(feature = "logging")]
3047 error!(target: "stdout", "{err_msg}");
3048
3049 return Err(LlamaCoreError::Operation(err_msg.into()));
3050 }
3051 };
3052
3053 let chat_graphs = chat_graphs.lock().map_err(|e| {
3054 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3055
3056 #[cfg(feature = "logging")]
3057 error!(target: "stdout", "{}", &err_msg);
3058
3059 LlamaCoreError::Operation(err_msg)
3060 })?;
3061
3062 match model_name {
3063 Some(model_name) => match chat_graphs.contains_key(model_name) {
3064 true => {
3065 let graph = chat_graphs.get(model_name).unwrap();
3066 Ok(graph.metadata.clone())
3067 }
3068 false => match chat_graphs.iter().next() {
3069 Some((_, graph)) => Ok(graph.metadata.clone()),
3070 None => {
3071 let err_msg = "There is no model available in the chat graphs.";
3072
3073 #[cfg(feature = "logging")]
3074 error!(target: "stdout", "{}", &err_msg);
3075
3076 Err(LlamaCoreError::Operation(err_msg.into()))
3077 }
3078 },
3079 },
3080 None => match chat_graphs.iter().next() {
3081 Some((_, graph)) => Ok(graph.metadata.clone()),
3082 None => {
3083 let err_msg = "There is no model available in the chat graphs.";
3084
3085 #[cfg(feature = "logging")]
3086 error!(target: "stdout", "{err_msg}");
3087
3088 Err(LlamaCoreError::Operation(err_msg.into()))
3089 }
3090 },
3091 }
3092}
3093
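/// Serialize the metadata to JSON and write it into input tensor 1 of the chat graph
/// selected by `model_name`, falling back to the first available graph when the name
/// is missing or unknown.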
3094fn update_model_metadata(
3095 model_name: Option<&String>,
3096 metadata: &GgmlMetadata,
3097) -> Result<(), LlamaCoreError> {
3098 let config = match serde_json::to_string(metadata) {
3099 Ok(config) => config,
3100 Err(e) => {
3101 let err_msg = format!("Fail to serialize metadata to a JSON string. {e}");
3102
3103 #[cfg(feature = "logging")]
3104 error!(target: "stdout", "{}", &err_msg);
3105
3106 return Err(LlamaCoreError::Operation(err_msg));
3107 }
3108 };
3109
3110 let chat_graphs = match CHAT_GRAPHS.get() {
3111 Some(chat_graphs) => chat_graphs,
3112 None => {
3113 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3114
3115 #[cfg(feature = "logging")]
3116 error!(target: "stdout", "{err_msg}");
3117
3118 return Err(LlamaCoreError::Operation(err_msg.into()));
3119 }
3120 };
3121
3122 let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3123 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. Reason: {e}");
3124
3125 #[cfg(feature = "logging")]
3126 error!(target: "stdout", "{}", &err_msg);
3127
3128 LlamaCoreError::Operation(err_msg)
3129 })?;
3130
3131 match model_name {
3132 Some(model_name) => {
3133 match chat_graphs.contains_key(model_name) {
3134 true => {
3135 let graph = chat_graphs.get_mut(model_name).unwrap();
3136 // update metadata
3137 set_tensor_data_u8(graph, 1, config.as_bytes())
3138 }
3139 false => match chat_graphs.iter_mut().next() {
3140 Some((_, graph)) => {
3141 // update metadata
3142 set_tensor_data_u8(graph, 1, config.as_bytes())
3143 }
3144 None => {
3145 let err_msg = "There is no model available in the chat graphs.";
3146
3147 #[cfg(feature = "logging")]
3148 error!(target: "stdout", "{}", &err_msg);
3149
3150 Err(LlamaCoreError::Operation(err_msg.into()))
3151 }
3152 },
3153 }
3154 }
3155 None => {
3156 match chat_graphs.iter_mut().next() {
3157 Some((_, graph)) => {
3158 // update metadata
3159 set_tensor_data_u8(graph, 1, config.as_bytes())
3160 }
3161 None => {
3162 let err_msg = "There is no model available in the chat graphs.";
3163
3164 #[cfg(feature = "logging")]
3165 error!(target: "stdout", "{err_msg}");
3166
3167 Err(LlamaCoreError::Operation(err_msg.into()))
3168 }
3169 }
3170 }
3171 }
3172}
3173
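/// Re-apply the metadata cached on the graph, undoing any per-request overrides made
/// by `check_model_metadata` and `update_n_predict`.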
3174fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
3175 // get metadata
3176 let metadata = get_model_metadata(model_name)?;
3177
3178 // update model with the original metadata
3179 update_model_metadata(model_name, &metadata)
3180}
3181
3182#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3183#[allow(dead_code)]
3184enum ContextFullState {
3185 Message,
3186 Usage,
3187 Done,
3188 EndOfSequence,
3189}
3190
3191#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3192#[allow(dead_code)]
3193enum StreamState {
3194 Usage,
3195 NoUsage,
3196 Done,
3197 EndOfSequence,
3198}
3199
3200#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3201#[allow(dead_code)]
3202enum PromptTooLongState {
3203 Message,
3204 Usage,
3205 Done,
3206 EndOfSequence,
3207}
3208
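/// A token stream backing a streaming chat-completion response.
///
/// Only one `ChatStream` may drive the backend at a time: the stream that flips
/// `CHAT_STREAM_ACTIVE` holds the lock, while later streams park their wakers in
/// `CHAT_STREAM_WAKER_QUEUE` until the active stream is dropped and wakes the next one.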
3209struct ChatStream {
3210 // id: String,
3211 model: Option<String>,
3212 // include_usage: bool,
3213 context_full_state: ContextFullState,
3214 prompt_too_long_state: PromptTooLongState,
3215 stream_state: StreamState,
3216 cache: Option<VecDeque<String>>,
3217 is_waiting: bool,
3218 has_lock: bool,
3219}
3220impl ChatStream {
3221 fn new(
3222 model: Option<String>,
3223 // id: String,
3224 // include_usage: bool,
3225 cache: Option<Vec<String>>,
3226 ) -> Self {
3227 // Try to acquire lock
3228 let has_lock = CHAT_STREAM_ACTIVE
3229 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3230 .is_ok();
3231
3232 #[cfg(feature = "logging")]
3233 if !has_lock {
3234 info!(target: "stdout", "Lock acquisition failed in ChatStream::new, creating with waiting status");
3235 }
3236
3237 ChatStream {
3238 // id,
3239 model,
3240 // include_usage,
3241 context_full_state: ContextFullState::Message,
3242 prompt_too_long_state: PromptTooLongState::Message,
3243 stream_state: StreamState::Usage,
3244 cache: cache.map(VecDeque::from),
3245 is_waiting: !has_lock,
3246 has_lock,
3247 }
3248 }
3249
    // Try to acquire the lock; returns whether it was acquired
3251 fn try_acquire_lock(&mut self) -> bool {
3252 if self.has_lock {
3253 return true;
3254 }
3255
3256 let acquired = CHAT_STREAM_ACTIVE
3257 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
3258 .is_ok();
3259
3260 if acquired {
3261 self.has_lock = true;
3262 self.is_waiting = false;
3263 }
3264
3265 acquired
3266 }
3267}
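// Dropping a `ChatStream` cleans up the single-turn context on the backend graph,
// resets the model metadata, and, when this stream held the global lock, releases
// `CHAT_STREAM_ACTIVE` and wakes the next waiting stream, if any.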
3268impl Drop for ChatStream {
3269 fn drop(&mut self) {
        // Cleanup is only needed if we hold the lock or if the stream was actually used
3271 if self.has_lock || (self.cache.is_none() && !self.is_waiting) {
3272 #[cfg(feature = "logging")]
3273 info!(target: "stdout", "Cleaning up context for ChatStream");
3274
3275 match &self.model {
3276 Some(model_name) => {
3277 match CHAT_GRAPHS.get() {
3278 Some(chat_graphs) => {
3279 match chat_graphs.lock() {
3280 Ok(mut chat_graphs) => match chat_graphs.contains_key(model_name) {
3281 true => {
3282 let graph = chat_graphs.get_mut(model_name).unwrap();
3283
3284 // clean up the context
3285 if let Err(e) = graph.finish_single() {
3286 let err_msg = format!(
3287 "Failed to clean up the context. Reason: {e}"
3288 );
3289
3290 #[cfg(feature = "logging")]
3291 error!(target: "stdout", "{}", &err_msg);
3292
3293 #[cfg(not(feature = "logging"))]
3294 println!(
3295 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3296 &err_msg
3297 );
3298 }
3299 }
3300 false => match chat_graphs.iter_mut().next() {
3301 Some((_, graph)) => {
3302 // clean up the context
3303 if let Err(e) = graph.finish_single() {
3304 let err_msg = format!(
3305 "Failed to clean up the context. Reason: {e}"
3306 );
3307
3308 #[cfg(feature = "logging")]
3309 error!(target: "stdout", "{}", &err_msg);
3310
3311 #[cfg(not(feature = "logging"))]
3312 println!(
3313 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3314 &err_msg
3315 );
3316 }
3317 }
3318 None => {
3319 let err_msg =
3320 "There is no model available in the chat graphs.";
3321
3322 #[cfg(feature = "logging")]
3323 error!(target: "stdout", "{}", &err_msg);
3324
3325 #[cfg(not(feature = "logging"))]
3326 println!(
3327 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3328 &err_msg
3329 );
3330 }
3331 },
3332 },
3333 Err(e) => {
3334 let err_msg =
3335 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3336
3337 #[cfg(feature = "logging")]
3338 error!(target: "stdout", "{}", &err_msg);
3339
3340 #[cfg(not(feature = "logging"))]
3341 println!(
3342 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3343 &err_msg
3344 );
3345 }
3346 }
3347 }
3348 None => {
3349 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3350
3351 #[cfg(feature = "logging")]
3352 error!(target: "stdout", "{}", &err_msg);
3353
3354 #[cfg(not(feature = "logging"))]
3355 println!(
3356 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3357 &err_msg
3358 );
3359 }
3360 };
3361 }
3362 None => {
3363 match CHAT_GRAPHS.get() {
3364 Some(chat_graphs) => {
3365 match chat_graphs.lock() {
3366 Ok(mut chat_graphs) => match chat_graphs.iter_mut().next() {
3367 Some((_, graph)) => {
3368 // clean up the context
3369 if let Err(e) = graph.finish_single() {
3370 let err_msg = format!(
3371 "Failed to clean up the context. Reason: {e}"
3372 );
3373
3374 #[cfg(feature = "logging")]
3375 error!(target: "stdout", "{}", &err_msg);
3376
3377 #[cfg(not(feature = "logging"))]
3378 println!(
3379 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3380 &err_msg
3381 );
3382 }
3383 }
3384 None => {
3385 let err_msg =
3386 "There is no model available in the chat graphs.";
3387
3388 #[cfg(feature = "logging")]
3389 error!(target: "stdout", "{err_msg}");
3390
3391 #[cfg(not(feature = "logging"))]
3392 println!(
3393 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3394 err_msg
3395 );
3396 }
3397 },
3398 Err(e) => {
3399 let err_msg =
3400 format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3401
3402 #[cfg(feature = "logging")]
3403 error!(target: "stdout", "{}", &err_msg);
3404
3405 #[cfg(not(feature = "logging"))]
3406 println!(
3407 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3408 &err_msg
3409 );
3410 }
3411 }
3412 }
3413 None => {
3414 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3415
3416 #[cfg(feature = "logging")]
3417 error!(target: "stdout", "{}", &err_msg);
3418
3419 #[cfg(not(feature = "logging"))]
3420 println!(
3421 "[ERROR][llama_core] Failed to clean up the context. Reason: {}",
3422 &err_msg
3423 );
3424 }
3425 };
3426 }
3427 }
3428
3429 #[cfg(feature = "logging")]
3430 info!(target: "stdout", "Model context cleanup done!");
3431 }
3432
3433 // reset the model metadata
3434 if let Err(e) = reset_model_metadata(self.model.as_ref()) {
3435 let err_msg = format!("Fail to reset model metadata. Reason: {e}");
3436
3437 #[cfg(feature = "logging")]
3438 error!(target: "stdout", "{}", &err_msg);
3439
3440 #[cfg(not(feature = "logging"))]
3441 println!("[ERROR][llama_core] {}", &err_msg);
3442 }
3443 #[cfg(feature = "logging")]
3444 info!(target: "stdout", "Model metadata reset done!");
3445
3446 // When dropping a ChatStream that held the lock, check if there are waiting streams
3447 if self.has_lock {
3448 // Reset the atomic flag
3449 CHAT_STREAM_ACTIVE.store(false, Ordering::SeqCst);
3450
3451 #[cfg(feature = "logging")]
3452 info!(target: "stdout", "Lock from ChatStream released");
3453
3454 // Wake up waiting streams
3455 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3456 if let Some(waker) = queue.pop_front() {
3457 #[cfg(feature = "logging")]
3458 info!(target: "stdout", "Waking up a waiting ChatStream");
3459
3460 waker.wake();
3461 }
3462 }
3463 }
3464 }
3465}
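// `poll_next` first ensures this stream holds the global lock (parking its waker in
// the queue otherwise), then either drains the pre-computed cache or asks
// `compute_stream` for the next chunk, ending the stream on an empty chunk or
// "[GGML] End of sequence".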
3466impl futures::Stream for ChatStream {
3467 type Item = Result<String, LlamaCoreError>;
3468
3469 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3470 let this = self.get_mut();
3471
3472 // If this is a waiting stream, try to acquire the lock
3473 if this.is_waiting {
3474 if !this.try_acquire_lock() {
3475 // Store the waker to be notified when the lock becomes available
3476 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3477 // Remove any previous instance of this waker
3478 queue.retain(|w| !w.will_wake(cx.waker()));
3479 // Add this waker to the queue
3480 queue.push_back(cx.waker().clone());
3481
3482 #[cfg(feature = "logging")]
3483 debug!(target: "stdout", "ChatStream is waiting for lock, added waker to queue");
3484 }
3485
3486 return Poll::Pending;
3487 }
3488
3489 #[cfg(feature = "logging")]
3490 info!(target: "stdout", "ChatStream acquired lock and is now active");
3491 // If we got here, we successfully acquired the lock and can proceed
3492 }
3493
3494 // Ensure we still have the lock
3495 if !this.has_lock && !this.try_acquire_lock() {
3496 // Lost the lock, need to wait
3497 this.is_waiting = true;
3498
3499 // Register waker to be notified when lock is available
3500 if let Ok(mut queue) = get_chat_stream_waker_queue().lock() {
3501 queue.retain(|w| !w.will_wake(cx.waker()));
3502 queue.push_back(cx.waker().clone());
3503 }
3504
3505 return Poll::Pending;
3506 }
3507
3508 if let Some(cache) = &mut this.cache {
3509 let x = cache.pop_front();
3510
3511 #[cfg(feature = "logging")]
3512 info!(target: "stdout", "Get the next item from the cache for ChatStream: {:?}", &x);
3513
3514 match x {
3515 Some(x) => Poll::Ready(Some(Ok(x))),
3516 None => Poll::Ready(None),
3517 }
3518 } else {
3519 let res = compute_stream(
3520 this.model.clone(),
3521 &mut this.prompt_too_long_state,
3522 &mut this.context_full_state,
3523 &mut this.stream_state,
3524 );
3525
3526 match res {
3527 Ok(x) => {
3528 #[cfg(feature = "logging")]
3529 info!(target: "stdout", "next item for ChatStream: {}", &x);
3530
3531 if x != "[GGML] End of sequence" && !x.is_empty() {
3532 Poll::Ready(Some(Ok(x)))
3533 } else {
3534 // stopped
3535 Poll::Ready(None)
3536 }
3537 }
3538 Err(e) => Poll::Ready(Some(Err(e))),
3539 }
3540 }
3541 }
3542}
3543
3544/// Helper function to get or initialize the waker queue for waiting ChatStreams
3545fn get_chat_stream_waker_queue() -> &'static Mutex<VecDeque<Waker>> {
3546 CHAT_STREAM_WAKER_QUEUE.get_or_init(|| {
3547 #[cfg(feature = "logging")]
3548 info!(target: "stdout", "Initializing ChatStream waker queue");
3549 Mutex::new(VecDeque::new())
3550 })
3551}
3552
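/// Compute the next chunk for a `ChatStream`, advancing the prompt-too-long,
/// context-full, and stream state machines passed in by the caller.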
3553#[allow(unused_variables)]
3554fn compute_stream(
3555 model_name: Option<String>,
3556 // id: String,
3557 // include_usage: bool,
3558 prompt_too_long_state: &mut PromptTooLongState,
3559 context_full_state: &mut ContextFullState,
3560 stream_state: &mut StreamState,
3561) -> Result<String, LlamaCoreError> {
3562 {
3563 // #[cfg(feature = "logging")]
3564 // info!(target: "stdout", "Computing stream chunk for ChatStream {}", &id);
3565
3566 // #[cfg(feature = "logging")]
3567 // debug!(target: "stdout", "prompt_too_long_state: {:?}", *prompt_too_long_state);
3568 // #[cfg(feature = "logging")]
3569 // debug!(target: "stdout", "context_full_state: {:?}", *context_full_state);
3570 // #[cfg(feature = "logging")]
3571 // debug!(target: "stdout", "stream_state: {:?}", *stream_state);
3572
3573 // if *prompt_too_long_state == PromptTooLongState::EndOfSequence
3574 // || *context_full_state == ContextFullState::EndOfSequence
3575 // || *stream_state == StreamState::EndOfSequence
3576 // {
3577 // #[cfg(feature = "logging")]
3578 // info!(target: "stdout", "Return the chat stream chunk!");
3579
3580 // return Ok("[GGML] End of sequence".to_string());
3581 // }
3582
3583 // let chat_graphs = match CHAT_GRAPHS.get() {
3584 // Some(chat_graphs) => chat_graphs,
3585 // None => {
3586 // let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
3587
3588 // #[cfg(feature = "logging")]
3589 // error!(target: "stdout", "{}", &err_msg);
3590
3591 // return Err(LlamaCoreError::Operation(err_msg.into()));
3592 // }
3593 // };
3594
3595 // // We're already holding the ChatStream lock, so we know we have exclusive access to the graph
3596 // let mut chat_graphs = chat_graphs.lock().map_err(|e| {
3597 // let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
3598
3599 // #[cfg(feature = "logging")]
3600 // error!(target: "stdout", "{}", &err_msg);
3601
3602 // LlamaCoreError::Operation(err_msg)
3603 // })?;
3604
3605 // // Get the graph based on model name
3606 // let res = match &model_name {
3607 // Some(model_name) => {
3608 // match chat_graphs.contains_key(model_name) {
3609 // true => {
3610 // let graph = chat_graphs.get_mut(model_name).unwrap();
3611 // // compute
3612 // match graph.compute_single() {
3613 // Ok(_) => {
3614 // #[cfg(feature = "logging")]
3615 // debug!(target: "stdout", "Compute the chat stream chunk successfully.");
3616
3617 // // Process according to state
3618 // match stream_state {
3619 // StreamState::Usage | StreamState::NoUsage => {
3620 // // Retrieve the output
3621 // let output_buffer =
3622 // get_output_buffer_single(graph, OUTPUT_TENSOR)?;
3623
3624 // #[cfg(feature = "logging")]
3625 // info!(target: "stdout", "retrieved the output buffer");
3626
3627 // // decode the output buffer to a utf8 string
3628 // let output = match String::from_utf8(output_buffer.clone()) {
3629 // Ok(token) => token,
3630 // Err(_) => {
3631 // let mutex = CACHED_UTF8_ENCODINGS
3632 // .get_or_init(|| Mutex::new(Vec::new()));
3633 // let mut cached_encodings = mutex.lock().map_err(|e| {
3634 // let err_msg = format!(
3635 // "Fail to acquire the lock of `UTF8_ENCODINGS`. Reason: {e}"
3636 // );
3637
3638 // #[cfg(feature = "logging")]
3639 // error!(target: "stdout", "{}", &err_msg);
3640
3641 // LlamaCoreError::Operation(err_msg)
3642 // })?;
3643
3644 // // cache the bytes for future decoding
3645 // cached_encodings.extend_from_slice(&output_buffer[..]);
3646
3647 // match String::from_utf8(cached_encodings.to_vec()) {
3648 // Ok(token) => {
3649 // // clear CACHED_UTF8_ENCODINGS
3650 // cached_encodings.clear();
3651
3652 // token
3653 // }
3654 // Err(e) => {
3655 // // TODO This is a temp check. In case, infinite cached encodings happen.
3656 // if cached_encodings.len() > 4 {
3657 // let err_msg = format!("Fail to convert a vector of bytes to string. The length of the utf8 bytes exceeds 4. {e}");
3658
3659 // #[cfg(feature = "logging")]
3660 // error!(target: "stdout", "{}", &err_msg);
3661
3662 // #[cfg(feature = "logging")]
3663 // error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
3664
3665 // // let token = String::from_utf8_lossy(
3666 // // &cached_encodings,
3667 // // )
3668 // // .to_string();
3669
3670 // // clear CACHED_UTF8_ENCODINGS
3671 // cached_encodings.clear();
3672
3673 // String::from("")
3674 // } else {
3675 // let warn_msg = format!("Fail to convert a vector of bytes to string. {e}");
3676
3677 // #[cfg(feature = "logging")]
3678 // warn!(target: "stdout", "{}", &warn_msg);
3679
3680 // String::from("")
3681 // }
3682 // }
3683 // }
3684 // }
3685 // };
3686
3687 // #[cfg(feature = "logging")]
3688 // info!(target: "stdout", "decoded the output buffer");
3689
3690 // let created = SystemTime::now()
3691 // .duration_since(std::time::UNIX_EPOCH)
3692 // .map_err(|e| {
3693 // let err_msg = format!(
3694 // "Failed to get the current time. Reason: {e}"
3695 // );
3696
3697 // #[cfg(feature = "logging")]
3698 // error!(target: "stdout", "{}", &err_msg);
3699
3700 // LlamaCoreError::Operation(err_msg)
3701 // })?;
3702
3703 // let chat_completion_chunk = ChatCompletionChunk {
3704 // id,
3705 // object: "chat.completion.chunk".to_string(),
3706 // created: created.as_secs(),
3707 // model: graph.name().to_owned(),
3708 // system_fingerprint: "fp_44709d6fcb".to_string(),
3709 // choices: vec![ChatCompletionChunkChoice {
3710 // index: 0,
3711 // delta: ChatCompletionChunkChoiceDelta {
3712 // role: ChatCompletionRole::Assistant,
3713 // content: Some(output),
3714 // tool_calls: vec![],
3715 // },
3716 // logprobs: None,
3717 // finish_reason: None,
3718 // }],
3719 // usage: None,
3720 // };
3721
3722 // #[cfg(feature = "logging")]
3723 // info!(target: "stdout", "created chat completion chunk");
3724
3725 // // serialize chat completion chunk
3726 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
3727 // .map_err(|e| {
3728 // let err_msg = format!(
3729 // "Failed to serialize chat completion chunk. Reason: {e}"
3730 // );
3731
3732 // #[cfg(feature = "logging")]
3733 // error!(target: "stdout", "{}", &err_msg);
3734
3735 // LlamaCoreError::Operation(err_msg)
3736 // })?;
3737
3738 // Ok(format!("data: {chunk_str}\n\n"))
3739 // }
3740 // StreamState::Done => {
3741 // *stream_state = StreamState::EndOfSequence;
3742
3743 // Ok("data: [DONE]\n\n".to_string())
3744 // }
3745 // StreamState::EndOfSequence => {
3746 // Ok("[GGML] End of sequence".to_string())
3747 // }
3748 // }
3749 // }
3750 // Err(wasmedge_wasi_nn::Error::BackendError(
3751 // wasmedge_wasi_nn::BackendError::EndOfSequence,
3752 // )) => {
3753 // #[cfg(feature = "logging")]
3754 // debug!(target: "stdout", "End of sequence");
3755
3756 // match stream_state {
3757 // StreamState::Usage => {
3758 // *stream_state = StreamState::Done;
3759
3760 // // retrieve the number of prompt and completion tokens
3761 // let token_info = get_token_info_by_graph(graph)?;
3762
3763 // let usage = Some(Usage {
3764 // prompt_tokens: token_info.prompt_tokens,
3765 // completion_tokens: token_info.completion_tokens,
3766 // total_tokens: token_info.prompt_tokens
3767 // + token_info.completion_tokens,
3768 // });
3769
3770 // #[cfg(feature = "logging")]
3771 // info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
3772
3773 // let created = SystemTime::now()
3774 // .duration_since(std::time::UNIX_EPOCH)
3775 // .map_err(|e| {
3776 // let err_msg = format!(
3777 // "Failed to get the current time. Reason: {e}"
3778 // );
3779
3780 // #[cfg(feature = "logging")]
3781 // error!(target: "stdout", "{}", &err_msg);
3782
3783 // LlamaCoreError::Operation(err_msg)
3784 // })?;
3785
3786 // let chat_completion_chunk = ChatCompletionChunk {
3787 // id,
3788 // object: "chat.completion.chunk".to_string(),
3789 // created: created.as_secs(),
3790 // model: graph.name().to_owned(),
3791 // system_fingerprint: "fp_44709d6fcb".to_string(),
3792 // choices: vec![],
3793 // usage,
3794 // };
3795
3796 // // serialize chat completion chunk
3797 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
3798 // .map_err(|e| {
3799 // let err_msg = format!(
3800 // "Failed to serialize chat completion chunk. Reason: {e}"
3801 // );
3802
3803 // #[cfg(feature = "logging")]
3804 // error!(target: "stdout", "{}", &err_msg);
3805
3806 // LlamaCoreError::Operation(err_msg)
3807 // })?;
3808
3809 // Ok(format!("data: {chunk_str}\n\n"))
3810 // }
3811 // StreamState::Done | StreamState::NoUsage => {
3812 // *stream_state = StreamState::EndOfSequence;
3813
3814 // Ok("data: [DONE]\n\n".to_string())
3815 // }
3816 // StreamState::EndOfSequence => {
3817 // Ok("[GGML] End of sequence".to_string())
3818 // }
3819 // }
3820 // }
3821 // Err(wasmedge_wasi_nn::Error::BackendError(
3822 // wasmedge_wasi_nn::BackendError::ContextFull,
3823 // )) => {
3824 // #[cfg(feature = "logging")]
3825 // debug!(target: "stdout", "Context full");
3826
3827 // match context_full_state {
3828 // ContextFullState::Message => {
3829 // match include_usage {
3830 // true => *context_full_state = ContextFullState::Usage,
3831 // false => *context_full_state = ContextFullState::Done,
3832 // }
3833
3834 // let created = SystemTime::now()
3835 // .duration_since(std::time::UNIX_EPOCH)
3836 // .map_err(|e| {
3837 // let err_msg = format!(
3838 // "Failed to get the current time. Reason: {e}"
3839 // );
3840
3841 // #[cfg(feature = "logging")]
3842 // error!(target: "stdout", "{}", &err_msg);
3843
3844 // LlamaCoreError::Operation(err_msg)
3845 // })?;
3846
3847 // let chat_completion_chunk = ChatCompletionChunk {
3848 // id,
3849 // object: "chat.completion.chunk".to_string(),
3850 // created: created.as_secs(),
3851 // model: graph.name().to_owned(),
3852 // system_fingerprint: "fp_44709d6fcb".to_string(),
3853 // choices: vec![ChatCompletionChunkChoice {
3854 // index: 0,
3855 // delta: ChatCompletionChunkChoiceDelta {
3856 // role: ChatCompletionRole::Assistant,
3857 // content: Some(
3858 // "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
3859 // ),
3860 // tool_calls: vec![],
3861 // },
3862 // logprobs: None,
3863 // finish_reason: Some(FinishReason::length),
3864 // }],
3865 // usage: None,
3866 // };
3867
3868 // // serialize chat completion chunk
3869 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
3870 // .map_err(|e| {
3871 // let err_msg = format!(
3872 // "Failed to serialize chat completion chunk. Reason: {e}"
3873 // );
3874
3875 // #[cfg(feature = "logging")]
3876 // error!(target: "stdout", "{}", &err_msg);
3877
3878 // LlamaCoreError::Operation(err_msg)
3879 // })?;
3880
3881 // Ok(format!("data: {chunk_str}\n\n"))
3882 // }
3883 // ContextFullState::Usage => {
3884 // *context_full_state = ContextFullState::Done;
3885
3886 // // retrieve the number of prompt and completion tokens
3887 // let token_info = get_token_info_by_graph(graph)?;
3888
3889 // let usage = Some(Usage {
3890 // prompt_tokens: token_info.prompt_tokens,
3891 // completion_tokens: token_info.completion_tokens,
3892 // total_tokens: token_info.prompt_tokens
3893 // + token_info.completion_tokens,
3894 // });
3895
3896 // let created = SystemTime::now()
3897 // .duration_since(std::time::UNIX_EPOCH)
3898 // .map_err(|e| {
3899 // let err_msg = format!(
3900 // "Failed to get the current time. Reason: {e}"
3901 // );
3902
3903 // #[cfg(feature = "logging")]
3904 // error!(target: "stdout", "{}", &err_msg);
3905
3906 // LlamaCoreError::Operation(err_msg)
3907 // })?;
3908
3909 // let chat_completion_chunk = ChatCompletionChunk {
3910 // id,
3911 // object: "chat.completion.chunk".to_string(),
3912 // created: created.as_secs(),
3913 // model: graph.name().to_owned(),
3914 // system_fingerprint: "fp_44709d6fcb".to_string(),
3915 // choices: vec![],
3916 // usage,
3917 // };
3918
3919 // // serialize chat completion chunk
3920 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
3921 // .map_err(|e| {
3922 // let err_msg = format!(
3923 // "Failed to serialize chat completion chunk. Reason: {e}"
3924 // );
3925
3926 // #[cfg(feature = "logging")]
3927 // error!(target: "stdout", "{}", &err_msg);
3928
3929 // LlamaCoreError::Operation(err_msg)
3930 // })?;
3931
3932 // Ok(format!("data: {chunk_str}\n\n"))
3933 // }
3934 // ContextFullState::Done => {
3935 // *context_full_state = ContextFullState::EndOfSequence;
3936
3937 // Ok("data: [DONE]\n\n".to_string())
3938 // }
3939 // ContextFullState::EndOfSequence => {
3940 // Ok("[GGML] End of sequence".to_string())
3941 // }
3942 // }
3943 // }
3944 // Err(wasmedge_wasi_nn::Error::BackendError(
3945 // wasmedge_wasi_nn::BackendError::PromptTooLong,
3946 // )) => {
3947 // #[cfg(feature = "logging")]
3948 // debug!(target: "stdout", "Prompt too long");
3949
3950 // match prompt_too_long_state {
3951 // PromptTooLongState::Message => {
3952 // match include_usage {
3953 // true => *prompt_too_long_state = PromptTooLongState::Usage,
3954 // false => *prompt_too_long_state = PromptTooLongState::Done,
3955 // }
3956
3957 // let created = SystemTime::now()
3958 // .duration_since(std::time::UNIX_EPOCH)
3959 // .map_err(|e| {
3960 // let err_msg = format!(
3961 // "Failed to get the current time. Reason: {e}"
3962 // );
3963
3964 // #[cfg(feature = "logging")]
3965 // error!(target: "stdout", "{}", &err_msg);
3966
3967 // LlamaCoreError::Operation(err_msg)
3968 // })?;
3969
3970 // let chat_completion_chunk = ChatCompletionChunk {
3971 // id,
3972 // object: "chat.completion.chunk".to_string(),
3973 // created: created.as_secs(),
3974 // model: graph.name().to_owned(),
3975 // system_fingerprint: "fp_44709d6fcb".to_string(),
3976 // choices: vec![ChatCompletionChunkChoice {
3977 // index: 0,
3978 // delta: ChatCompletionChunkChoiceDelta {
3979 // role: ChatCompletionRole::Assistant,
3980 // content: None,
3981 // tool_calls: vec![],
3982 // },
3983 // logprobs: None,
3984 // finish_reason: Some(FinishReason::length),
3985 // }],
3986 // usage: None,
3987 // };
3988
3989 // // serialize chat completion chunk
3990 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
3991 // .map_err(|e| {
3992 // let err_msg = format!(
3993 // "Failed to serialize chat completion chunk. Reason: {e}"
3994 // );
3995
3996 // #[cfg(feature = "logging")]
3997 // error!(target: "stdout", "{}", &err_msg);
3998
3999 // LlamaCoreError::Operation(err_msg)
4000 // })?;
4001
4002 // Ok(format!("data: {chunk_str}\n\n"))
4003 // }
4004 // PromptTooLongState::Usage => {
4005 // *prompt_too_long_state = PromptTooLongState::Done;
4006
4007 // // retrieve the number of prompt and completion tokens
4008 // let token_info = get_token_info_by_graph(graph)?;
4009
4010 // let usage = Some(Usage {
4011 // prompt_tokens: token_info.prompt_tokens,
4012 // completion_tokens: token_info.completion_tokens,
4013 // total_tokens: token_info.prompt_tokens
4014 // + token_info.completion_tokens,
4015 // });
4016
4017 // let created = SystemTime::now()
4018 // .duration_since(std::time::UNIX_EPOCH)
4019 // .map_err(|e| {
4020 // let err_msg = format!(
4021 // "Failed to get the current time. Reason: {e}"
4022 // );
4023
4024 // #[cfg(feature = "logging")]
4025 // error!(target: "stdout", "{}", &err_msg);
4026
4027 // LlamaCoreError::Operation(err_msg)
4028 // })?;
4029
4030 // let chat_completion_chunk = ChatCompletionChunk {
4031 // id,
4032 // object: "chat.completion.chunk".to_string(),
4033 // created: created.as_secs(),
4034 // model: graph.name().to_owned(),
4035 // system_fingerprint: "fp_44709d6fcb".to_string(),
4036 // choices: vec![],
4037 // usage,
4038 // };
4039
4040 // // serialize chat completion chunk
4041 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4042 // .map_err(|e| {
4043 // let err_msg = format!(
4044 // "Failed to serialize chat completion chunk. Reason: {e}"
4045 // );
4046
4047 // #[cfg(feature = "logging")]
4048 // error!(target: "stdout", "{}", &err_msg);
4049
4050 // LlamaCoreError::Operation(err_msg)
4051 // })?;
4052
4053 // Ok(format!("data: {chunk_str}\n\n"))
4054 // }
4055 // PromptTooLongState::Done => {
4056 // *prompt_too_long_state = PromptTooLongState::EndOfSequence;
4057
4058 // Ok("data: [DONE]\n\n".to_string())
4059 // }
4060 // PromptTooLongState::EndOfSequence => {
4061 // Ok("[GGML] End of sequence".to_string())
4062 // }
4063 // }
4064 // }
4065 // Err(e) => {
4066 // let err_msg =
4067 // format!("Failed to compute the chat completion. Reason: {e}");
4068
4069 // #[cfg(feature = "logging")]
4070 // error!(target: "stdout", "{}", &err_msg);
4071
4072 // Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4073 // err_msg,
4074 // )))
4075 // }
4076 // }
4077 // }
4078 // false => {
4079 // match chat_graphs.iter_mut().next() {
4080 // Some((_, graph)) => {
4081 // // compute
4082 // match graph.compute_single() {
4083 // Ok(_) => {
4084 // #[cfg(feature = "logging")]
4085 // debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4086
4087 // match stream_state {
4088 // StreamState::Usage | StreamState::NoUsage => {
4089 // // Retrieve the output
4090 // let output_buffer =
4091 // get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4092
4093 // #[cfg(feature = "logging")]
4094 // info!(target: "stdout", "retrieved the output buffer");
4095
4096 // // decode the output buffer to a utf8 string
4097 // let output = match String::from_utf8(
4098 // output_buffer.clone(),
4099 // ) {
4100 // Ok(token) => token,
4101 // Err(_) => {
4102 // let mutex = CACHED_UTF8_ENCODINGS
4103 // .get_or_init(|| Mutex::new(Vec::new()));
4104 // let mut cached_encodings = mutex.lock().map_err(|e| {
4105 // let err_msg = format!(
4106 //                                                 "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4107 // );
4108
4109 // #[cfg(feature = "logging")]
4110 // error!(target: "stdout", "{}", &err_msg);
4111
4112 // LlamaCoreError::Operation(err_msg)
4113 // })?;
4114
4115 // // cache the bytes for future decoding
4116 // cached_encodings
4117 // .extend_from_slice(&output_buffer[..]);
4118
4119 // match String::from_utf8(
4120 // cached_encodings.to_vec(),
4121 // ) {
4122 // Ok(token) => {
4123 // // clear encodings
4124 // cached_encodings.clear();
4125
4126 // token
4127 // }
4128 // Err(e) => {
4129 //                                             // TODO This is a temporary check to guard against the cached encodings growing without bound.
4130 // if cached_encodings.len() > 4 {
4131 //                                                 let err_msg = format!("Failed to convert a vector of bytes to string: the cached UTF-8 bytes exceed 4 in length. {e}");
4132
4133 // #[cfg(feature = "logging")]
4134 // error!(target: "stdout", "{}", &err_msg);
4135
4136 // #[cfg(feature = "logging")]
4137 // error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4138
4139 // // let token =
4140 // // String::from_utf8_lossy(
4141 // // &cached_encodings,
4142 // // )
4143 // // .to_string();
4144
4145 // // clear CACHED_UTF8_ENCODINGS
4146 // cached_encodings.clear();
4147
4148 // String::from("")
4149 // } else {
4150 //                                                 let warn_msg = format!("Failed to convert a vector of bytes to string. {e}");
4151
4152 // #[cfg(feature = "logging")]
4153 // warn!(target: "stdout", "{}", &warn_msg);
4154
4155 // String::from("")
4156 // }
4157 // }
4158 // }
4159 // }
4160 // };
4161
4162 // #[cfg(feature = "logging")]
4163 // info!(target: "stdout", "decoded the output buffer");
4164
4165 // let created = SystemTime::now()
4166 // .duration_since(std::time::UNIX_EPOCH)
4167 // .map_err(|e| {
4168 // let err_msg = format!(
4169 // "Failed to get the current time. Reason: {e}"
4170 // );
4171
4172 // #[cfg(feature = "logging")]
4173 // error!(target: "stdout", "{}", &err_msg);
4174
4175 // LlamaCoreError::Operation(err_msg)
4176 // })?;
4177
4178 // let chat_completion_chunk = ChatCompletionChunk {
4179 // id,
4180 // object: "chat.completion.chunk".to_string(),
4181 // created: created.as_secs(),
4182 // model: graph.name().to_owned(),
4183 // system_fingerprint: "fp_44709d6fcb".to_string(),
4184 // choices: vec![ChatCompletionChunkChoice {
4185 // index: 0,
4186 // delta: ChatCompletionChunkChoiceDelta {
4187 // role: ChatCompletionRole::Assistant,
4188 // content: Some(output),
4189 // tool_calls: vec![],
4190 // },
4191 // logprobs: None,
4192 // finish_reason: None,
4193 // }],
4194 // usage: None,
4195 // };
4196
4197 // #[cfg(feature = "logging")]
4198 // info!(target: "stdout", "created chat completion chunk");
4199
4200 // // serialize chat completion chunk
4201 // let chunk_str =
4202 // serde_json::to_string(&chat_completion_chunk)
4203 // .map_err(|e| {
4204 // let err_msg = format!(
4205 // "Failed to serialize chat completion chunk. Reason: {e}"
4206 // );
4207
4208 // #[cfg(feature = "logging")]
4209 // error!(target: "stdout", "{}", &err_msg);
4210
4211 // LlamaCoreError::Operation(err_msg)
4212 // })?;
4213
4214 // Ok(format!("data: {chunk_str}\n\n"))
4215 // }
4216 // StreamState::Done => {
4217 // *stream_state = StreamState::EndOfSequence;
4218
4219 // Ok("data: [DONE]\n\n".to_string())
4220 // }
4221 // StreamState::EndOfSequence => {
4222 // Ok("[GGML] End of sequence".to_string())
4223 // }
4224 // }
4225 // }
4226 // Err(wasmedge_wasi_nn::Error::BackendError(
4227 // wasmedge_wasi_nn::BackendError::EndOfSequence,
4228 // )) => {
4229 // #[cfg(feature = "logging")]
4230 // debug!(target: "stdout", "End of sequence");
4231
4232 // match stream_state {
4233 // StreamState::Usage => {
4234 // *stream_state = StreamState::Done;
4235
4236 // // retrieve the number of prompt and completion tokens
4237 // let token_info = get_token_info_by_graph(graph)?;
4238
4239 // let usage = Some(Usage {
4240 // prompt_tokens: token_info.prompt_tokens,
4241 // completion_tokens: token_info.completion_tokens,
4242 // total_tokens: token_info.prompt_tokens
4243 // + token_info.completion_tokens,
4244 // });
4245
4246 // #[cfg(feature = "logging")]
4247 // info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4248
4249 // let created = SystemTime::now()
4250 // .duration_since(std::time::UNIX_EPOCH)
4251 // .map_err(|e| {
4252 // let err_msg = format!(
4253 // "Failed to get the current time. Reason: {e}"
4254 // );
4255
4256 // #[cfg(feature = "logging")]
4257 // error!(target: "stdout", "{}", &err_msg);
4258
4259 // LlamaCoreError::Operation(err_msg)
4260 // })?;
4261
4262 // let chat_completion_chunk = ChatCompletionChunk {
4263 // id,
4264 // object: "chat.completion.chunk".to_string(),
4265 // created: created.as_secs(),
4266 // model: graph.name().to_owned(),
4267 // system_fingerprint: "fp_44709d6fcb".to_string(),
4268 // choices: vec![],
4269 // usage,
4270 // };
4271
4272 // // serialize chat completion chunk
4273 // let chunk_str =
4274 // serde_json::to_string(&chat_completion_chunk)
4275 // .map_err(|e| {
4276 // let err_msg = format!(
4277 // "Failed to serialize chat completion chunk. Reason: {e}"
4278 // );
4279
4280 // #[cfg(feature = "logging")]
4281 // error!(target: "stdout", "{}", &err_msg);
4282
4283 // LlamaCoreError::Operation(err_msg)
4284 // })?;
4285
4286 // Ok(format!("data: {chunk_str}\n\n"))
4287 // }
4288 // StreamState::Done | StreamState::NoUsage => {
4289 // *stream_state = StreamState::EndOfSequence;
4290
4291 // Ok("data: [DONE]\n\n".to_string())
4292 // }
4293 // StreamState::EndOfSequence => {
4294 // Ok("[GGML] End of sequence".to_string())
4295 // }
4296 // }
4297 // }
4298 // Err(wasmedge_wasi_nn::Error::BackendError(
4299 // wasmedge_wasi_nn::BackendError::ContextFull,
4300 // )) => {
4301 // #[cfg(feature = "logging")]
4302 // debug!(target: "stdout", "Context full");
4303
4304 // match context_full_state {
4305 // ContextFullState::Message => {
4306 // match include_usage {
4307 // true => {
4308 // *context_full_state = ContextFullState::Usage
4309 // }
4310 // false => {
4311 // *context_full_state = ContextFullState::Done
4312 // }
4313 // }
4314
4315 // let created = SystemTime::now()
4316 // .duration_since(std::time::UNIX_EPOCH)
4317 // .map_err(|e| {
4318 // let err_msg = format!(
4319 // "Failed to get the current time. Reason: {e}"
4320 // );
4321
4322 // #[cfg(feature = "logging")]
4323 // error!(target: "stdout", "{}", &err_msg);
4324
4325 // LlamaCoreError::Operation(err_msg)
4326 // })?;
4327
4328 // let chat_completion_chunk = ChatCompletionChunk {
4329 // id,
4330 // object: "chat.completion.chunk".to_string(),
4331 // created: created.as_secs(),
4332 // model: graph.name().to_owned(),
4333 // system_fingerprint: "fp_44709d6fcb".to_string(),
4334 // choices: vec![ChatCompletionChunkChoice {
4335 // index: 0,
4336 // delta: ChatCompletionChunkChoiceDelta {
4337 // role: ChatCompletionRole::Assistant,
4338 // content: Some(
4339 // "<|WASMEDGE-GGML-CONTEXT-FULL|>"
4340 // .to_string(),
4341 // ),
4342 // tool_calls: vec![],
4343 // },
4344 // logprobs: None,
4345 // finish_reason: Some(FinishReason::length),
4346 // }],
4347 // usage: None,
4348 // };
4349
4350 // // serialize chat completion chunk
4351 // let chunk_str =
4352 // serde_json::to_string(&chat_completion_chunk)
4353 // .map_err(|e| {
4354 // let err_msg = format!(
4355 // "Failed to serialize chat completion chunk. Reason: {e}"
4356 // );
4357
4358 // #[cfg(feature = "logging")]
4359 // error!(target: "stdout", "{}", &err_msg);
4360
4361 // LlamaCoreError::Operation(err_msg)
4362 // })?;
4363
4364 // Ok(format!("data: {chunk_str}\n\n"))
4365 // }
4366 // ContextFullState::Usage => {
4367 // *context_full_state = ContextFullState::Done;
4368
4369 // // retrieve the number of prompt and completion tokens
4370 // let token_info = get_token_info_by_graph(graph)?;
4371
4372 // let usage = Some(Usage {
4373 // prompt_tokens: token_info.prompt_tokens,
4374 // completion_tokens: token_info.completion_tokens,
4375 // total_tokens: token_info.prompt_tokens
4376 // + token_info.completion_tokens,
4377 // });
4378
4379 // let created = SystemTime::now()
4380 // .duration_since(std::time::UNIX_EPOCH)
4381 // .map_err(|e| {
4382 // let err_msg = format!(
4383 // "Failed to get the current time. Reason: {e}"
4384 // );
4385
4386 // #[cfg(feature = "logging")]
4387 // error!(target: "stdout", "{}", &err_msg);
4388
4389 // LlamaCoreError::Operation(err_msg)
4390 // })?;
4391
4392 // let chat_completion_chunk = ChatCompletionChunk {
4393 // id,
4394 // object: "chat.completion.chunk".to_string(),
4395 // created: created.as_secs(),
4396 // model: graph.name().to_owned(),
4397 // system_fingerprint: "fp_44709d6fcb".to_string(),
4398 // choices: vec![],
4399 // usage,
4400 // };
4401
4402 // // serialize chat completion chunk
4403 // let chunk_str =
4404 // serde_json::to_string(&chat_completion_chunk)
4405 // .map_err(|e| {
4406 // let err_msg = format!(
4407 // "Failed to serialize chat completion chunk. Reason: {e}"
4408 // );
4409
4410 // #[cfg(feature = "logging")]
4411 // error!(target: "stdout", "{}", &err_msg);
4412
4413 // LlamaCoreError::Operation(err_msg)
4414 // })?;
4415
4416 // Ok(format!("data: {chunk_str}\n\n"))
4417 // }
4418 // ContextFullState::Done => {
4419 // *context_full_state = ContextFullState::EndOfSequence;
4420
4421 // Ok("data: [DONE]\n\n".to_string())
4422 // }
4423 // ContextFullState::EndOfSequence => {
4424 // Ok("[GGML] End of sequence".to_string())
4425 // }
4426 // }
4427 // }
4428 // Err(wasmedge_wasi_nn::Error::BackendError(
4429 // wasmedge_wasi_nn::BackendError::PromptTooLong,
4430 // )) => {
4431 // #[cfg(feature = "logging")]
4432 // debug!(target: "stdout", "Prompt too long");
4433
4434 // match prompt_too_long_state {
4435 // PromptTooLongState::Message => {
4436 // match include_usage {
4437 // true => {
4438 // *prompt_too_long_state =
4439 // PromptTooLongState::Usage
4440 // }
4441 // false => {
4442 // *prompt_too_long_state =
4443 // PromptTooLongState::Done
4444 // }
4445 // }
4446
4447 // let created = SystemTime::now()
4448 // .duration_since(std::time::UNIX_EPOCH)
4449 // .map_err(|e| {
4450 // let err_msg = format!(
4451 // "Failed to get the current time. Reason: {e}"
4452 // );
4453
4454 // #[cfg(feature = "logging")]
4455 // error!(target: "stdout", "{}", &err_msg);
4456
4457 // LlamaCoreError::Operation(err_msg)
4458 // })?;
4459
4460 // let chat_completion_chunk = ChatCompletionChunk {
4461 // id,
4462 // object: "chat.completion.chunk".to_string(),
4463 // created: created.as_secs(),
4464 // model: graph.name().to_owned(),
4465 // system_fingerprint: "fp_44709d6fcb".to_string(),
4466 // choices: vec![ChatCompletionChunkChoice {
4467 // index: 0,
4468 // delta: ChatCompletionChunkChoiceDelta {
4469 // role: ChatCompletionRole::Assistant,
4470 // content: None,
4471 // tool_calls: vec![],
4472 // },
4473 // logprobs: None,
4474 // finish_reason: Some(FinishReason::length),
4475 // }],
4476 // usage: None,
4477 // };
4478
4479 // // serialize chat completion chunk
4480 // let chunk_str =
4481 // serde_json::to_string(&chat_completion_chunk)
4482 // .map_err(|e| {
4483 // let err_msg = format!(
4484 // "Failed to serialize chat completion chunk. Reason: {e}"
4485 // );
4486
4487 // #[cfg(feature = "logging")]
4488 // error!(target: "stdout", "{}", &err_msg);
4489
4490 // LlamaCoreError::Operation(err_msg)
4491 // })?;
4492
4493 // Ok(format!("data: {chunk_str}\n\n"))
4494 // }
4495 // PromptTooLongState::Usage => {
4496 // *prompt_too_long_state = PromptTooLongState::Done;
4497
4498 // // retrieve the number of prompt and completion tokens
4499 // let token_info = get_token_info_by_graph(graph)?;
4500
4501 // let usage = Some(Usage {
4502 // prompt_tokens: token_info.prompt_tokens,
4503 // completion_tokens: token_info.completion_tokens,
4504 // total_tokens: token_info.prompt_tokens
4505 // + token_info.completion_tokens,
4506 // });
4507
4508 // let created = SystemTime::now()
4509 // .duration_since(std::time::UNIX_EPOCH)
4510 // .map_err(|e| {
4511 // let err_msg = format!(
4512 // "Failed to get the current time. Reason: {e}"
4513 // );
4514
4515 // #[cfg(feature = "logging")]
4516 // error!(target: "stdout", "{}", &err_msg);
4517
4518 // LlamaCoreError::Operation(err_msg)
4519 // })?;
4520
4521 // let chat_completion_chunk = ChatCompletionChunk {
4522 // id,
4523 // object: "chat.completion.chunk".to_string(),
4524 // created: created.as_secs(),
4525 // model: graph.name().to_owned(),
4526 // system_fingerprint: "fp_44709d6fcb".to_string(),
4527 // choices: vec![],
4528 // usage,
4529 // };
4530
4531 // // serialize chat completion chunk
4532 // let chunk_str =
4533 // serde_json::to_string(&chat_completion_chunk)
4534 // .map_err(|e| {
4535 // let err_msg = format!(
4536 // "Failed to serialize chat completion chunk. Reason: {e}"
4537 // );
4538
4539 // #[cfg(feature = "logging")]
4540 // error!(target: "stdout", "{}", &err_msg);
4541
4542 // LlamaCoreError::Operation(err_msg)
4543 // })?;
4544
4545 // Ok(format!("data: {chunk_str}\n\n"))
4546 // }
4547 // PromptTooLongState::Done => {
4548 // *prompt_too_long_state =
4549 // PromptTooLongState::EndOfSequence;
4550
4551 // Ok("data: [DONE]\n\n".to_string())
4552 // }
4553 // PromptTooLongState::EndOfSequence => {
4554 // Ok("[GGML] End of sequence".to_string())
4555 // }
4556 // }
4557 // }
4558 // Err(e) => {
4559 // let err_msg = format!(
4560 // "Failed to compute the chat completion. Reason: {e}"
4561 // );
4562
4563 // #[cfg(feature = "logging")]
4564 // error!(target: "stdout", "{}", &err_msg);
4565
4566 // Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
4567 // err_msg,
4568 // )))
4569 // }
4570 // }
4571 // }
4572 // None => {
4573 // let err_msg = "There is no model available in the chat graphs.";
4574
4575 // #[cfg(feature = "logging")]
4576 // error!(target: "stdout", "{}", &err_msg);
4577
4578 // Err(LlamaCoreError::Operation(err_msg.into()))
4579 // }
4580 // }
4581 // }
4582 // }
4583 // }
4584 // None => {
4585 // match chat_graphs.iter_mut().next() {
4586 // Some((_, graph)) => {
4587 // // compute
4588 // match graph.compute_single() {
4589 // Ok(_) => {
4590 // #[cfg(feature = "logging")]
4591 // debug!(target: "stdout", "Compute the chat stream chunk successfully.");
4592
4593 // match stream_state {
4594 // StreamState::Usage | StreamState::NoUsage => {
4595 // // Retrieve the output
4596 // let output_buffer =
4597 // get_output_buffer_single(graph, OUTPUT_TENSOR)?;
4598
4599 // #[cfg(feature = "logging")]
4600 // info!(target: "stdout", "retrieved the output buffer");
4601
4602 // // decode the output buffer to a utf8 string
4603 // let output = match String::from_utf8(output_buffer.clone()) {
4604 // Ok(token) => token,
4605 // Err(_) => {
4606 // let mutex = CACHED_UTF8_ENCODINGS
4607 // .get_or_init(|| Mutex::new(Vec::new()));
4608 // let mut cached_encodings = mutex.lock().map_err(|e| {
4609 // let err_msg = format!(
4610 //                                         "Failed to acquire the lock of `CACHED_UTF8_ENCODINGS`. Reason: {e}"
4611 // );
4612
4613 // #[cfg(feature = "logging")]
4614 // error!(target: "stdout", "{}", &err_msg);
4615
4616 // LlamaCoreError::Operation(err_msg)
4617 // })?;
4618
4619 // cached_encodings.extend_from_slice(&output_buffer[..]);
4620
4621 // match String::from_utf8(cached_encodings.to_vec()) {
4622 // Ok(token) => {
4623 // // clear encodings
4624 // cached_encodings.clear();
4625
4626 // token
4627 // }
4628 // Err(e) => {
4629 //                                     // TODO This is a temporary check to guard against the cached encodings growing without bound.
4630 // if cached_encodings.len() > 4 {
4631 //                                         let err_msg = format!("Failed to convert a vector of bytes to string: the cached UTF-8 bytes exceed 4 in length. {e}");
4632
4633 // #[cfg(feature = "logging")]
4634 // error!(target: "stdout", "{}", &err_msg);
4635
4636 // #[cfg(feature = "logging")]
4637 // error!(target: "stdout", "The cached buffer: {:?}", &cached_encodings[..]);
4638
4639 // // let token = String::from_utf8_lossy(
4640 // // &cached_encodings,
4641 // // )
4642 // // .to_string();
4643
4644 // // clear CACHED_UTF8_ENCODINGS
4645 // cached_encodings.clear();
4646
4647 // String::from("")
4648 // } else {
4649 //                                         let warn_msg = format!("Failed to convert a vector of bytes to string. {e}");
4650
4651 // #[cfg(feature = "logging")]
4652 // warn!(target: "stdout", "{}", &warn_msg);
4653
4654 // String::from("")
4655 // }
4656 // }
4657 // }
4658 // }
4659 // };
4660
4661 // #[cfg(feature = "logging")]
4662 // info!(target: "stdout", "decoded the output buffer");
4663
4664 // let created = SystemTime::now()
4665 // .duration_since(std::time::UNIX_EPOCH)
4666 // .map_err(|e| {
4667 // let err_msg = format!(
4668 // "Failed to get the current time. Reason: {e}"
4669 // );
4670
4671 // #[cfg(feature = "logging")]
4672 // error!(target: "stdout", "{}", &err_msg);
4673
4674 // LlamaCoreError::Operation(err_msg)
4675 // })?;
4676
4677 // let chat_completion_chunk = ChatCompletionChunk {
4678 // id,
4679 // object: "chat.completion.chunk".to_string(),
4680 // created: created.as_secs(),
4681 // model: graph.name().to_owned(),
4682 // system_fingerprint: "fp_44709d6fcb".to_string(),
4683 // choices: vec![ChatCompletionChunkChoice {
4684 // index: 0,
4685 // delta: ChatCompletionChunkChoiceDelta {
4686 // role: ChatCompletionRole::Assistant,
4687 // content: Some(output),
4688 // tool_calls: vec![],
4689 // },
4690 // logprobs: None,
4691 // finish_reason: None,
4692 // }],
4693 // usage: None,
4694 // };
4695
4696 // #[cfg(feature = "logging")]
4697 // info!(target: "stdout", "created chat completion chunk");
4698
4699 // // serialize chat completion chunk
4700 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4701 // .map_err(|e| {
4702 // let err_msg = format!(
4703 // "Failed to serialize chat completion chunk. Reason: {e}"
4704 // );
4705
4706 // #[cfg(feature = "logging")]
4707 // error!(target: "stdout", "{}", &err_msg);
4708
4709 // LlamaCoreError::Operation(err_msg)
4710 // })?;
4711
4712 // Ok(format!("data: {chunk_str}\n\n"))
4713 // }
4714 // StreamState::Done => {
4715 // *stream_state = StreamState::EndOfSequence;
4716
4717 // Ok("data: [DONE]\n\n".to_string())
4718 // }
4719 // StreamState::EndOfSequence => {
4720 // Ok("[GGML] End of sequence".to_string())
4721 // }
4722 // }
4723 // }
4724 // Err(wasmedge_wasi_nn::Error::BackendError(
4725 // wasmedge_wasi_nn::BackendError::EndOfSequence,
4726 // )) => {
4727 // #[cfg(feature = "logging")]
4728 // debug!(target: "stdout", "End of sequence");
4729
4730 // match stream_state {
4731 // StreamState::Usage => {
4732 // *stream_state = StreamState::Done;
4733
4734 // // retrieve the number of prompt and completion tokens
4735 // let token_info = get_token_info_by_graph(graph)?;
4736
4737 // let usage = Some(Usage {
4738 // prompt_tokens: token_info.prompt_tokens,
4739 // completion_tokens: token_info.completion_tokens,
4740 // total_tokens: token_info.prompt_tokens
4741 // + token_info.completion_tokens,
4742 // });
4743
4744 // #[cfg(feature = "logging")]
4745 // info!(target: "stdout", "token_info: {} prompt tokens, {} completion tokens", token_info.prompt_tokens, token_info.completion_tokens);
4746
4747 // let created = SystemTime::now()
4748 // .duration_since(std::time::UNIX_EPOCH)
4749 // .map_err(|e| {
4750 // let err_msg = format!(
4751 // "Failed to get the current time. Reason: {e}"
4752 // );
4753
4754 // #[cfg(feature = "logging")]
4755 // error!(target: "stdout", "{}", &err_msg);
4756
4757 // LlamaCoreError::Operation(err_msg)
4758 // })?;
4759
4760 // let chat_completion_chunk = ChatCompletionChunk {
4761 // id,
4762 // object: "chat.completion.chunk".to_string(),
4763 // created: created.as_secs(),
4764 // model: graph.name().to_owned(),
4765 // system_fingerprint: "fp_44709d6fcb".to_string(),
4766 // choices: vec![],
4767 // usage,
4768 // };
4769
4770 // // serialize chat completion chunk
4771 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4772 // .map_err(|e| {
4773 // let err_msg = format!(
4774 // "Failed to serialize chat completion chunk. Reason: {e}"
4775 // );
4776
4777 // #[cfg(feature = "logging")]
4778 // error!(target: "stdout", "{}", &err_msg);
4779
4780 // LlamaCoreError::Operation(err_msg)
4781 // })?;
4782
4783 // Ok(format!("data: {chunk_str}\n\n"))
4784 // }
4785 // StreamState::Done | StreamState::NoUsage => {
4786 // *stream_state = StreamState::EndOfSequence;
4787
4788 // Ok("data: [DONE]\n\n".to_string())
4789 // }
4790 // StreamState::EndOfSequence => {
4791 // Ok("[GGML] End of sequence".to_string())
4792 // }
4793 // }
4794 // }
4795 // Err(wasmedge_wasi_nn::Error::BackendError(
4796 // wasmedge_wasi_nn::BackendError::ContextFull,
4797 // )) => {
4798 // #[cfg(feature = "logging")]
4799 // debug!(target: "stdout", "Context full");
4800
4801 // match context_full_state {
4802 // ContextFullState::Message => {
4803 // match include_usage {
4804 // true => *context_full_state = ContextFullState::Usage,
4805 // false => *context_full_state = ContextFullState::Done,
4806 // }
4807
4808 // let created = SystemTime::now()
4809 // .duration_since(std::time::UNIX_EPOCH)
4810 // .map_err(|e| {
4811 // let err_msg = format!(
4812 // "Failed to get the current time. Reason: {e}"
4813 // );
4814
4815 // #[cfg(feature = "logging")]
4816 // error!(target: "stdout", "{}", &err_msg);
4817
4818 // LlamaCoreError::Operation(err_msg)
4819 // })?;
4820
4821 // let chat_completion_chunk = ChatCompletionChunk {
4822 // id,
4823 // object: "chat.completion.chunk".to_string(),
4824 // created: created.as_secs(),
4825 // model: graph.name().to_owned(),
4826 // system_fingerprint: "fp_44709d6fcb".to_string(),
4827 // choices: vec![ChatCompletionChunkChoice {
4828 // index: 0,
4829 // delta: ChatCompletionChunkChoiceDelta {
4830 // role: ChatCompletionRole::Assistant,
4831 // content: Some(
4832 // "<|WASMEDGE-GGML-CONTEXT-FULL|>".to_string(),
4833 // ),
4834 // tool_calls: vec![],
4835 // },
4836 // logprobs: None,
4837 // finish_reason: Some(FinishReason::length),
4838 // }],
4839 // usage: None,
4840 // };
4841
4842 // // serialize chat completion chunk
4843 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4844 // .map_err(|e| {
4845 // let err_msg = format!(
4846 // "Failed to serialize chat completion chunk. Reason: {e}"
4847 // );
4848
4849 // #[cfg(feature = "logging")]
4850 // error!(target: "stdout", "{}", &err_msg);
4851
4852 // LlamaCoreError::Operation(err_msg)
4853 // })?;
4854
4855 // Ok(format!("data: {chunk_str}\n\n"))
4856 // }
4857 // ContextFullState::Usage => {
4858 // *context_full_state = ContextFullState::Done;
4859
4860 // // retrieve the number of prompt and completion tokens
4861 // let token_info = get_token_info_by_graph(graph)?;
4862
4863 // let usage = Some(Usage {
4864 // prompt_tokens: token_info.prompt_tokens,
4865 // completion_tokens: token_info.completion_tokens,
4866 // total_tokens: token_info.prompt_tokens
4867 // + token_info.completion_tokens,
4868 // });
4869
4870 // let created = SystemTime::now()
4871 // .duration_since(std::time::UNIX_EPOCH)
4872 // .map_err(|e| {
4873 // let err_msg = format!(
4874 // "Failed to get the current time. Reason: {e}"
4875 // );
4876
4877 // #[cfg(feature = "logging")]
4878 // error!(target: "stdout", "{}", &err_msg);
4879
4880 // LlamaCoreError::Operation(err_msg)
4881 // })?;
4882
4883 // let chat_completion_chunk = ChatCompletionChunk {
4884 // id,
4885 // object: "chat.completion.chunk".to_string(),
4886 // created: created.as_secs(),
4887 // model: graph.name().to_owned(),
4888 // system_fingerprint: "fp_44709d6fcb".to_string(),
4889 // choices: vec![],
4890 // usage,
4891 // };
4892
4893 // // serialize chat completion chunk
4894 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4895 // .map_err(|e| {
4896 // let err_msg = format!(
4897 // "Failed to serialize chat completion chunk. Reason: {e}"
4898 // );
4899
4900 // #[cfg(feature = "logging")]
4901 // error!(target: "stdout", "{}", &err_msg);
4902
4903 // LlamaCoreError::Operation(err_msg)
4904 // })?;
4905
4906 // Ok(format!("data: {chunk_str}\n\n"))
4907 // }
4908 // ContextFullState::Done => {
4909 // *context_full_state = ContextFullState::EndOfSequence;
4910
4911 // Ok("data: [DONE]\n\n".to_string())
4912 // }
4913 // ContextFullState::EndOfSequence => {
4914 // Ok("[GGML] End of sequence".to_string())
4915 // }
4916 // }
4917 // }
4918 // Err(wasmedge_wasi_nn::Error::BackendError(
4919 // wasmedge_wasi_nn::BackendError::PromptTooLong,
4920 // )) => {
4921 // #[cfg(feature = "logging")]
4922 // debug!(target: "stdout", "Prompt too long");
4923
4924 // match prompt_too_long_state {
4925 // PromptTooLongState::Message => {
4926 // match include_usage {
4927 // true => *prompt_too_long_state = PromptTooLongState::Usage,
4928 // false => *prompt_too_long_state = PromptTooLongState::Done,
4929 // }
4930
4931 // let created = SystemTime::now()
4932 // .duration_since(std::time::UNIX_EPOCH)
4933 // .map_err(|e| {
4934 // let err_msg = format!(
4935 // "Failed to get the current time. Reason: {e}"
4936 // );
4937
4938 // #[cfg(feature = "logging")]
4939 // error!(target: "stdout", "{}", &err_msg);
4940
4941 // LlamaCoreError::Operation(err_msg)
4942 // })?;
4943
4944 // let chat_completion_chunk = ChatCompletionChunk {
4945 // id,
4946 // object: "chat.completion.chunk".to_string(),
4947 // created: created.as_secs(),
4948 // model: graph.name().to_owned(),
4949 // system_fingerprint: "fp_44709d6fcb".to_string(),
4950 // choices: vec![ChatCompletionChunkChoice {
4951 // index: 0,
4952 // delta: ChatCompletionChunkChoiceDelta {
4953 // role: ChatCompletionRole::Assistant,
4954 // content: None,
4955 // tool_calls: vec![],
4956 // },
4957 // logprobs: None,
4958 // finish_reason: Some(FinishReason::length),
4959 // }],
4960 // usage: None,
4961 // };
4962
4963 // // serialize chat completion chunk
4964 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
4965 // .map_err(|e| {
4966 // let err_msg = format!(
4967 // "Failed to serialize chat completion chunk. Reason: {e}"
4968 // );
4969
4970 // #[cfg(feature = "logging")]
4971 // error!(target: "stdout", "{}", &err_msg);
4972
4973 // LlamaCoreError::Operation(err_msg)
4974 // })?;
4975
4976 // Ok(format!("data: {chunk_str}\n\n"))
4977 // }
4978 // PromptTooLongState::Usage => {
4979 // *prompt_too_long_state = PromptTooLongState::Done;
4980
4981 // // retrieve the number of prompt and completion tokens
4982 // let token_info = get_token_info_by_graph(graph)?;
4983
4984 // let usage = Some(Usage {
4985 // prompt_tokens: token_info.prompt_tokens,
4986 // completion_tokens: token_info.completion_tokens,
4987 // total_tokens: token_info.prompt_tokens
4988 // + token_info.completion_tokens,
4989 // });
4990
4991 // let created = SystemTime::now()
4992 // .duration_since(std::time::UNIX_EPOCH)
4993 // .map_err(|e| {
4994 // let err_msg = format!(
4995 // "Failed to get the current time. Reason: {e}"
4996 // );
4997
4998 // #[cfg(feature = "logging")]
4999 // error!(target: "stdout", "{}", &err_msg);
5000
5001 // LlamaCoreError::Operation(err_msg)
5002 // })?;
5003
5004 // let chat_completion_chunk = ChatCompletionChunk {
5005 // id,
5006 // object: "chat.completion.chunk".to_string(),
5007 // created: created.as_secs(),
5008 // model: graph.name().to_owned(),
5009 // system_fingerprint: "fp_44709d6fcb".to_string(),
5010 // choices: vec![],
5011 // usage,
5012 // };
5013
5014 // // serialize chat completion chunk
5015 // let chunk_str = serde_json::to_string(&chat_completion_chunk)
5016 // .map_err(|e| {
5017 // let err_msg = format!(
5018 // "Failed to serialize chat completion chunk. Reason: {e}"
5019 // );
5020
5021 // #[cfg(feature = "logging")]
5022 // error!(target: "stdout", "{}", &err_msg);
5023
5024 // LlamaCoreError::Operation(err_msg)
5025 // })?;
5026
5027 // Ok(format!("data: {chunk_str}\n\n"))
5028 // }
5029 // PromptTooLongState::Done => {
5030 // *prompt_too_long_state = PromptTooLongState::EndOfSequence;
5031
5032 // Ok("data: [DONE]\n\n".to_string())
5033 // }
5034 // PromptTooLongState::EndOfSequence => {
5035 // Ok("[GGML] End of sequence".to_string())
5036 // }
5037 // }
5038 // }
5039 // Err(e) => {
5040 // let err_msg =
5041 // format!("Failed to compute the chat completion. Reason: {e}");
5042
5043 // #[cfg(feature = "logging")]
5044 // error!(target: "stdout", "{}", &err_msg);
5045
5046 // Err(LlamaCoreError::Backend(BackendError::ComputeSingle(
5047 // err_msg,
5048 // )))
5049 // }
5050 // }
5051 // }
5052 // None => {
5053 // let err_msg = "There is no model available in the chat graphs.";
5054
5055 // #[cfg(feature = "logging")]
5056 // error!(target: "stdout", "{}", &err_msg);
5057
5058 // Err(LlamaCoreError::Operation(err_msg.into()))
5059 // }
5060 // }
5061 // }
5062 // };
5063
5064 // #[cfg(feature = "logging")]
5065 // info!(target: "stdout", "Return the chat stream chunk!");
5066
5067 // res
5068 }
5069
5070 todo!("stream_chat_completion is not implemented yet");
5071}
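
// A minimal, self-contained sketch of the incremental UTF-8 decoding strategy used in the
// commented-out streaming loop above: when the bytes produced by a single compute step are
// not valid UTF-8 on their own (a multi-byte character split across chunks), they are
// cached and decoding is retried once more bytes arrive. The names and the test itself are
// illustrative only; they are not part of the streaming implementation.
#[cfg(test)]
#[test]
fn cached_utf8_decoding_sketch() {
    let mut cached: Vec<u8> = Vec::new();
    let mut decoded = String::new();

    // "é" (0xC3 0xA9) split across two chunks, as a backend may emit it.
    for chunk in [&[0xC3u8][..], &[0xA9u8, b'!'][..]] {
        cached.extend_from_slice(chunk);
        match std::str::from_utf8(&cached) {
            Ok(token) => {
                // The cached bytes now form a complete UTF-8 sequence: emit and clear.
                decoded.push_str(token);
                cached.clear();
            }
            Err(_) => {
                // Incomplete sequence: keep the bytes cached and wait for the next chunk.
            }
        }
    }

    assert_eq!(decoded, "é!");
    assert!(cached.is_empty());
}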
5072
5073// #[allow(dead_code)]
5074// #[derive(Debug)]
5075// struct ParseResult {
5076// raw: String,
5077// content: Option<String>,
5078// tool_calls: Vec<ToolCall>,
5079// }
5080
5081/// Convert an Input into a vector of ChatCompletionRequestMessage.
5082/// Only InputItem::InputMessage is handled; other variants are skipped with a warning.
5083fn to_chat_messages(input: &Input) -> Result<Vec<ChatCompletionRequestMessage>, LlamaCoreError> {
5084 match input {
5085 Input::Text(text) => {
5086 let content = ChatCompletionUserMessageContent::Text(text.clone());
5087 let user_message = ChatCompletionRequestMessage::new_user_message(content, None);
5088 // Simple text converts to a user message
5089 Ok(vec![user_message])
5090 }
5091 Input::InputItemList(items) => {
5092 let mut messages = Vec::new();
5093 for item in items {
5094 match item {
5095 InputItem::InputMessage { content, role, .. } => {
5096 let message = input_message_to_chat_message(content, role)?;
5097 messages.push(message);
5098 }
5099 _ => {
5100 #[cfg(feature = "logging")]
5101 warn!(target: "stdout", "Skipping unsupported InputItem variant");
5102 }
5103 }
5104 }
5105 Ok(messages)
5106 }
5107 }
5108}
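
// A minimal sketch of the plain-text path of `to_chat_messages` above: an `Input::Text`
// becomes exactly one user message. This assumes `Input::Text` wraps a `String`, as the
// match arm above suggests; it is an illustrative test, not a specification.
#[cfg(test)]
#[test]
fn text_input_maps_to_a_single_user_message() {
    let input = Input::Text("Hello".to_string());
    let messages = to_chat_messages(&input).expect("plain text input should convert");
    assert_eq!(messages.len(), 1);
}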
5109
5110/// Helper function to convert InputMessage to ChatCompletionRequestMessage
5111fn input_message_to_chat_message(
5112 content: &InputMessageContent,
5113 role: &str,
5114) -> Result<ChatCompletionRequestMessage, LlamaCoreError> {
5115 match role {
5116 "user" => {
5117 let content = input_message_content_to_chat_message_content(content)?;
5118 let content = ChatCompletionUserMessageContent::Text(content);
5119 Ok(ChatCompletionRequestMessage::new_user_message(
5120 content, None,
5121 ))
5122 }
5123 "assistant" => {
5124 let content = input_message_content_to_chat_message_content(content)?;
5125 Ok(ChatCompletionRequestMessage::new_assistant_message(
5126 Some(content),
5127 None,
5128 None,
5129 ))
5130 }
5131 "system" => {
5132 let content = input_message_content_to_chat_message_content(content)?;
5133 Ok(ChatCompletionRequestMessage::new_system_message(
5134 content, None,
5135 ))
5136 }
5137 "developer" => {
5138 let content = input_message_content_to_chat_message_content(content)?;
5139 Ok(ChatCompletionRequestMessage::new_developer_message(
5140 content, None,
5141 ))
5142 }
5143 _ => {
5144            let error_msg = format!("Unsupported role: {role}");
5145
5146 #[cfg(feature = "logging")]
5147 error!(target: "stdout", "{}", &error_msg);
5148
5149 Err(LlamaCoreError::Operation(error_msg))
5150 }
5151 }
5152}
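
// A minimal sketch of the role dispatch in `input_message_to_chat_message` above: the four
// known roles succeed, while any other role is rejected with an `Operation` error. The
// role string "moderator" is just an arbitrary unsupported value used for illustration.
#[cfg(test)]
#[test]
fn unknown_roles_are_rejected() {
    let content = InputMessageContent::Text("hi".to_string());
    assert!(input_message_to_chat_message(&content, "user").is_ok());
    assert!(input_message_to_chat_message(&content, "moderator").is_err());
}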
5153
5154/// Extract the plain-text content from an InputMessageContent.
5155fn input_message_content_to_chat_message_content(
5156 content: &InputMessageContent,
5157) -> Result<String, LlamaCoreError> {
5158 match content {
5159 InputMessageContent::Text(text) => Ok(text.clone()),
5160 InputMessageContent::InputItemContentList(_items) => {
5161            let error_msg = "InputMessageContent::InputItemContentList is not supported";
5162
5163 #[cfg(feature = "logging")]
5164 error!(target: "stdout", "{}", error_msg);
5165
5166 Err(LlamaCoreError::Operation(error_msg.into()))
5167 }
5168 }
5169}
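
// A minimal sketch of the text-only contract of the helper above: plain text is passed
// through unchanged, while `InputItemContentList` is reported as unsupported. The empty
// `Vec::new()` assumes the variant carries a `Vec`, as its name and the `_items` binding
// above suggest; this is an illustrative test only.
#[cfg(test)]
#[test]
fn only_plain_text_content_is_supported() {
    let text = InputMessageContent::Text("hello".to_string());
    assert_eq!(
        input_message_content_to_chat_message_content(&text).unwrap(),
        "hello"
    );

    let list = InputMessageContent::InputItemContentList(Vec::new());
    assert!(input_message_content_to_chat_message_content(&list).is_err());
}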