use crate::{
    error::{BackendError, LlamaCoreError},
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{get_output_buffer, get_token_info_by_graph},
    Graph, RunningMode, CHAT_GRAPHS, OUTPUT_TENSOR,
};
use endpoints::{
    common::{FinishReason, Usage},
    completions::{CompletionChoice, CompletionObject, CompletionPrompt, CompletionRequest},
};
use std::time::SystemTime;

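/// Generate a completion for the prompt in the given request.
///
/// Completions are served by the chat graphs, so the `chat` running mode must
/// be enabled before this function is called.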
pub async fn completions(request: &CompletionRequest) -> Result<CompletionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Generate completions");

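    // completions are only available when the chat running mode is enabled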
    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) {
        let err_msg = "Completions are only supported in the chat mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

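    // a multi-text prompt is flattened into a single space-separated string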
    let prompt = match &request.prompt {
        CompletionPrompt::SingleText(prompt) => prompt.to_owned(),
        CompletionPrompt::MultiText(prompts) => prompts.join(" "),
    };

    compute(prompt.trim(), request.model.as_ref())
}

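/// Dispatch the prompt to the chat graph that matches `model_name`, falling
/// back to the first loaded graph when no name is given or the name is unknown.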
fn compute(
    prompt: impl AsRef<str>,
    model_name: Option<&String>,
) -> std::result::Result<CompletionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute completions");

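    // the global `CHAT_GRAPHS` store must have been initialized before completions can run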
    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `CHAT_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

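    // pick the graph for the requested model, or fall back to the first loaded graph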
    match model_name {
        Some(model_name) if chat_graphs.contains_key(model_name) => {
            let graph = chat_graphs.get_mut(model_name).unwrap();
            compute_by_graph(graph, prompt)
        }
        _ => match chat_graphs.iter_mut().next() {
            Some((_, graph)) => compute_by_graph(graph, prompt),
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

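/// Run a single completion on the given graph: feed the prompt in as the input
/// tensor, execute the inference, and wrap the decoded output in a
/// `CompletionObject`.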
fn compute_by_graph(
    graph: &mut Graph<GgmlMetadata>,
    prompt: impl AsRef<str>,
) -> std::result::Result<CompletionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute completions by graph");

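    // disable the `embeddings` option if it is still enabled on this graph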
    if graph.metadata.embeddings {
        graph.metadata.embeddings = false;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The `embeddings` field of metadata is set to false.");

        graph.update_metadata()?;
    }

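    // feed the prompt into the graph as a UTF-8 byte tensor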
    let tensor_data = prompt.as_ref().as_bytes().to_vec();
    graph
        .set_input(0, wasmedge_wasi_nn::TensorType::U8, &[1], &tensor_data)
        .map_err(|e| {
            let err_msg = format!("Failed to set the input tensor. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Backend(BackendError::SetInput(err_msg))
        })?;

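    // run the inference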
    graph.compute().map_err(|e| {
        let err_msg = format!("Failed to execute the inference. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Backend(BackendError::Compute(err_msg))
    })?;

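    // retrieve the output tensor and decode it as a UTF-8 string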
    let buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;

    let model_answer = String::from_utf8(buffer).map_err(|e| {
        let err_msg = format!(
            "Failed to decode the buffer of the inference result to a utf-8 string. {}",
            e
        );

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;
    let answer = model_answer.trim();

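    // prompt and completion token counts for this run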
    let token_info = get_token_info_by_graph(graph)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Prompt tokens: {}, Completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

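    // creation timestamp in seconds since the Unix epoch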
    let created = SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_err(|e| {
            let err_msg = format!("Failed to get the current time. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Completions generated successfully.");

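    // assemble the completion response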
    Ok(CompletionObject {
        id: uuid::Uuid::new_v4().to_string(),
        object: String::from("text_completion"),
        created: created.as_secs(),
        model: graph.name().to_string(),
        choices: vec![CompletionChoice {
            index: 0,
            text: String::from(answer),
            finish_reason: FinishReason::stop,
            logprobs: None,
        }],
        usage: Usage {
            prompt_tokens: token_info.prompt_tokens,
            completion_tokens: token_info.completion_tokens,
            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
        },
    })
}