llama_core/completions.rs

//! Define APIs for completions.

use crate::{
    error::{BackendError, LlamaCoreError},
    metadata::ggml::GgmlMetadata,
    running_mode,
    utils::{get_output_buffer, get_token_info_by_graph},
    Graph, RunningMode, CHAT_GRAPHS, OUTPUT_TENSOR,
};
use endpoints::{
    common::{FinishReason, Usage},
    completions::{CompletionChoice, CompletionObject, CompletionPrompt, CompletionRequest},
};
use std::time::SystemTime;

/// Given a prompt, the model will return one or more predicted completions along with the probabilities of alternative tokens at each position.
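/// A `CompletionPrompt::MultiText` prompt is joined into a single prompt with spaces before inference.
///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doc test).
/// The model name and the `CompletionRequest` fields other than `prompt` and
/// `model` (filled via `Default`, assuming a `Default` impl exists) are
/// illustrative assumptions, not part of this module:
///
/// ```ignore
/// use endpoints::completions::{CompletionPrompt, CompletionRequest};
///
/// // (inside an async context)
/// let request = CompletionRequest {
///     model: Some("llama-3-8b-instruct".to_string()),
///     prompt: CompletionPrompt::SingleText("Once upon a time".to_string()),
///     ..Default::default()
/// };
///
/// let completion = completions(&request).await?;
/// println!("{}", completion.choices[0].text);
/// ```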
pub async fn completions(request: &CompletionRequest) -> Result<CompletionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Generate completions");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::CHAT) {
        let err_msg = "The completions endpoint is only supported in the chat mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::Operation(err_msg.to_string()));
    }

    let prompt = match &request.prompt {
        CompletionPrompt::SingleText(prompt) => prompt.to_owned(),
        CompletionPrompt::MultiText(prompts) => prompts.join(" "),
    };

    compute(prompt.trim(), request.model.as_ref())
}

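/// Routes the completion to the chat graph registered under `model_name`; if no
/// such graph exists (or no model name is given), it falls back to the first
/// available chat graph.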
fn compute(
    prompt: impl AsRef<str>,
    model_name: Option<&String>,
) -> std::result::Result<CompletionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute completions");

    let chat_graphs = match CHAT_GRAPHS.get() {
        Some(chat_graphs) => chat_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `CHAT_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut chat_graphs = chat_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock on `CHAT_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    let graph = match model_name {
        Some(model_name) if chat_graphs.contains_key(model_name) => {
            chat_graphs.get_mut(model_name).unwrap()
        }
        _ => match chat_graphs.iter_mut().next() {
            Some((_, graph)) => graph,
            None => {
                let err_msg = "There is no model available in the chat graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Operation(err_msg.into()));
            }
        },
    };

    compute_by_graph(graph, prompt)
}

/// Runs inference on the given chat graph and returns the completion output.
fn compute_by_graph(
    graph: &mut Graph<GgmlMetadata>,
    prompt: impl AsRef<str>,
) -> std::result::Result<CompletionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Compute completions by graph");

    // disable the `embeddings` mode, if enabled, so the graph returns a text completion instead of embeddings
    if graph.metadata.embeddings {
        graph.metadata.embeddings = false;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The `embeddings` field of metadata is set to false.");

        graph.update_metadata()?;
    }

    // set input
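    // The prompt is passed to the wasi-nn backend as a single-element (`&[1]`) u8 tensor
    // holding the raw UTF-8 bytes of the prompt string; tokenization is left to the backend.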
    let tensor_data = prompt.as_ref().as_bytes().to_vec();
    graph
        .set_input(0, wasmedge_wasi_nn::TensorType::U8, &[1], &tensor_data)
        .map_err(|e| {
            let err_msg = format!("Failed to set the input tensor. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Backend(BackendError::SetInput(err_msg))
        })?;

    // execute the inference
    graph.compute().map_err(|e| {
        let err_msg = format!("Failed to execute the inference. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Backend(BackendError::Compute(err_msg))
    })?;

    // retrieve the output
    let buffer = get_output_buffer(graph, OUTPUT_TENSOR)?;

    // convert the inference result to a string
    let model_answer = String::from_utf8(buffer).map_err(|e| {
        let err_msg = format!(
            "Failed to decode the buffer of the inference result to a utf-8 string. {}",
            e
        );

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;
    let answer = model_answer.trim();

    // retrieve the number of prompt and completion tokens
    let token_info = get_token_info_by_graph(graph)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Prompt tokens: {}, Completion tokens: {}", token_info.prompt_tokens, token_info.completion_tokens);

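    // timestamp for the `created` field (seconds since the Unix epoch)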
    let created = SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_err(|e| {
            let err_msg = format!("Failed to get the current time. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Completions generated successfully.");

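    // assemble the completion response with a fresh id and the token usage reported by the graph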
    Ok(CompletionObject {
        id: uuid::Uuid::new_v4().to_string(),
        object: String::from("text_completion"),
        created: created.as_secs(),
        model: graph.name().to_string(),
        choices: vec![CompletionChoice {
            index: 0,
            text: String::from(answer),
            finish_reason: FinishReason::stop,
            logprobs: None,
        }],
        usage: Usage {
            prompt_tokens: token_info.prompt_tokens,
            completion_tokens: token_info.completion_tokens,
            total_tokens: token_info.prompt_tokens + token_info.completion_tokens,
        },
    })
}