llama_core/graph.rs

//! Define Graph and GraphBuilder APIs for creating a new computation graph.

use crate::{error::LlamaCoreError, utils::set_tensor_data_u8, BaseMetadata};
use wasmedge_wasi_nn::{
    Error as WasiNnError, Graph as WasiNnGraph, GraphExecutionContext, TensorType,
};

/// Builder for creating a new computation graph.
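///
/// # Example
///
/// A minimal usage sketch, not taken from this crate's tests. It assumes the builder is
/// reachable at `llama_core::graph`, a hypothetical `MyMetadata` type implementing
/// `BaseMetadata + serde::Serialize + Clone + Default`, and a made-up model path.
///
/// ```ignore
/// use llama_core::graph::{EngineType, GraphBuilder};
///
/// let metadata = MyMetadata::default();
/// let graph = GraphBuilder::new(EngineType::Ggml)?
///     .with_config(metadata)?
///     .use_gpu()
///     .build_from_files(["models/llama-3-8b.gguf"])?;
/// ```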
#[derive(Debug)]
pub struct GraphBuilder<M: BaseMetadata + serde::Serialize + Clone + Default> {
    metadata: Option<M>,
    wasi_nn_graph_builder: wasmedge_wasi_nn::GraphBuilder,
}
impl<M: BaseMetadata + serde::Serialize + Clone + Default> GraphBuilder<M> {
    /// Create a new computation graph builder.
    pub fn new(ty: EngineType) -> Result<Self, LlamaCoreError> {
        let encoding = match ty {
            EngineType::Ggml => wasmedge_wasi_nn::GraphEncoding::Ggml,
            EngineType::Whisper => wasmedge_wasi_nn::GraphEncoding::Whisper,
            EngineType::Piper => wasmedge_wasi_nn::GraphEncoding::Piper,
        };

        let wasi_nn_graph_builder =
            wasmedge_wasi_nn::GraphBuilder::new(encoding, wasmedge_wasi_nn::ExecutionTarget::AUTO);

        Ok(Self {
            metadata: None,
            wasi_nn_graph_builder,
        })
    }

    /// Set the metadata used to configure the graph. The metadata is serialized to a
    /// JSON string and passed to the underlying wasi-nn graph builder as its config.
    pub fn with_config(mut self, metadata: M) -> Result<Self, LlamaCoreError> {
        let config = serde_json::to_string(&metadata).map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;
        self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.config(config);
        self.metadata = Some(metadata.clone());

        Ok(self)
    }

    /// Use the CPU as the execution target.
    pub fn use_cpu(mut self) -> Self {
        self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.cpu();
        self
    }

    /// Use the GPU as the execution target.
    pub fn use_gpu(mut self) -> Self {
        self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.gpu();
        self
    }

    /// Use the TPU as the execution target.
    pub fn use_tpu(mut self) -> Self {
        self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.tpu();
        self
    }

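    /// Create a new computation graph from a model held in memory as bytes.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a GGUF model file on disk and a `metadata` value
    /// already constructed by the caller (names are illustrative only):
    ///
    /// ```ignore
    /// let bytes = std::fs::read("models/llama-3-8b.gguf")?;
    /// let graph = GraphBuilder::new(EngineType::Ggml)?
    ///     .with_config(metadata)?
    ///     .build_from_buffer([bytes])?;
    /// ```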
    pub fn build_from_buffer<B>(
        self,
        bytes_array: impl AsRef<[B]>,
    ) -> Result<Graph<M>, LlamaCoreError>
    where
        B: AsRef<[u8]>,
    {
        // load the model
        let graph = self
            .wasi_nn_graph_builder
            .build_from_bytes(bytes_array)
            .map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

        // initialize the execution context
        let context = graph.init_execution_context().map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        let created = std::time::SystemTime::now();

        Ok(Graph {
            created,
            metadata: self.metadata.clone().unwrap_or_default(),
            graph,
            context,
        })
    }

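    /// Create a new computation graph from one or more model files on disk.
    ///
    /// # Example
    ///
    /// A minimal sketch; the file names below are illustrative, e.g. a model split
    /// across multiple GGUF shards:
    ///
    /// ```ignore
    /// let graph = GraphBuilder::new(EngineType::Ggml)?
    ///     .with_config(metadata)?
    ///     .build_from_files([
    ///         "models/llama-3-70b-00001-of-00002.gguf",
    ///         "models/llama-3-70b-00002-of-00002.gguf",
    ///     ])?;
    /// ```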
    pub fn build_from_files<P>(self, files: impl AsRef<[P]>) -> Result<Graph<M>, LlamaCoreError>
    where
        P: AsRef<std::path::Path>,
    {
        // load the model
        let graph = self
            .wasi_nn_graph_builder
            .build_from_files(files)
            .map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

        // initialize the execution context
        let context = graph.init_execution_context().map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        let created = std::time::SystemTime::now();

        Ok(Graph {
            created,
            metadata: self.metadata.clone().unwrap_or_default(),
            graph,
            context,
        })
    }

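    /// Create a new computation graph from a model already registered with the runtime
    /// under the alias returned by `metadata.model_alias()`. Requires that the metadata
    /// has been set via [`with_config`](Self::with_config).
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the runtime has a model preloaded under the alias
    /// carried by `metadata`:
    ///
    /// ```ignore
    /// let graph = GraphBuilder::new(EngineType::Ggml)?
    ///     .with_config(metadata)?
    ///     .build_from_cache()?;
    /// ```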
    pub fn build_from_cache(self) -> Result<Graph<M>, LlamaCoreError> {
        match &self.metadata {
            Some(metadata) => {
                // load the model
                let graph = self
                    .wasi_nn_graph_builder
                    .build_from_cache(metadata.model_alias())
                    .map_err(|e| {
                        let err_msg = e.to_string();

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        LlamaCoreError::Operation(err_msg)
                    })?;

                // initialize the execution context
                let context = graph.init_execution_context().map_err(|e| {
                    let err_msg = e.to_string();

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                let created = std::time::SystemTime::now();

                Ok(Graph {
                    created,
                    metadata: metadata.clone(),
                    graph,
                    context,
                })
            }
            None => {
                let err_msg =
                    "Failed to create a Graph from cache. Reason: Metadata is not provided."
                        .to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg))
            }
        }
    }
}

/// Wrapper of the `wasmedge_wasi_nn::Graph` struct.
#[derive(Debug)]
pub struct Graph<M: BaseMetadata + serde::Serialize + Clone + Default> {
    /// Time at which the graph was created.
    pub created: std::time::SystemTime,
    /// Metadata of the model.
    pub metadata: M,
    graph: WasiNnGraph,
    context: GraphExecutionContext,
}
impl<M: BaseMetadata + serde::Serialize + Clone + Default> Graph<M> {
    /// Create a new computation graph from the given metadata.
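    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a GGML model preloaded under the alias carried by a
    /// hypothetical `MyMetadata` value:
    ///
    /// ```ignore
    /// let metadata = MyMetadata::default();
    /// let graph = Graph::new(metadata)?;
    /// println!("model: {}, alias: {}", graph.name(), graph.alias());
    /// ```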
    pub fn new(metadata: M) -> Result<Self, LlamaCoreError> {
        let config = serde_json::to_string(&metadata).map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        // load the model
        let graph = wasmedge_wasi_nn::GraphBuilder::new(
            wasmedge_wasi_nn::GraphEncoding::Ggml,
            wasmedge_wasi_nn::ExecutionTarget::AUTO,
        )
        .config(config)
        .build_from_cache(metadata.model_alias())
        .map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        // initialize the execution context
        let context = graph.init_execution_context().map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        let created = std::time::SystemTime::now();

        Ok(Self {
            created,
            metadata: metadata.clone(),
            graph,
            context,
        })
    }

    /// Get the name of the model.
    pub fn name(&self) -> &str {
        self.metadata.model_name()
    }

    /// Get the alias of the model.
    pub fn alias(&self) -> &str {
        self.metadata.model_alias()
    }

    /// Update the model metadata in the runtime. The current metadata is serialized to
    /// JSON and written to input tensor index 1.
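    ///
    /// # Example
    ///
    /// A minimal sketch; `updated_metadata` is hypothetical since its fields depend on
    /// the concrete `BaseMetadata` implementation:
    ///
    /// ```ignore
    /// graph.metadata = updated_metadata;
    /// graph.update_metadata()?;
    /// ```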
    pub fn update_metadata(&mut self) -> Result<(), LlamaCoreError> {
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Update metadata for the model named {}", self.name());

        // update metadata
        let config = match serde_json::to_string(&self.metadata) {
            Ok(config) => config,
            Err(e) => {
                let err_msg = format!("Failed to update metadata. Reason: Failed to serialize metadata to a JSON string. {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Operation(err_msg));
            }
        };

        let res = set_tensor_data_u8(self, 1, config.as_bytes());

        #[cfg(feature = "logging")]
        info!(target: "stdout", "Metadata updated successfully.");

        res
    }

    /// Set the input tensor at the given index. The data is not limited to [u8](https://doc.rust-lang.org/nightly/std/primitive.u8.html); it may also be [f32](https://doc.rust-lang.org/nightly/std/primitive.f32.html), [i32](https://doc.rust-lang.org/nightly/std/primitive.i32.html), etc.
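    ///
    /// # Example
    ///
    /// A minimal sketch of a single (non-streaming) inference pass; the prompt, tensor
    /// index, dimensions, and output buffer size are illustrative:
    ///
    /// ```ignore
    /// let prompt = "What is the capital of France?";
    /// graph.set_input(0, TensorType::U8, &[1], prompt.as_bytes())?;
    /// graph.compute()?;
    ///
    /// let mut out_buffer = vec![0u8; 1024 * 1024];
    /// let size = graph.get_output(0, &mut out_buffer)?;
    /// let answer = String::from_utf8_lossy(&out_buffer[..size]);
    /// ```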
    pub fn set_input<T: Sized>(
        &mut self,
        index: usize,
        tensor_type: TensorType,
        dimensions: &[usize],
        data: impl AsRef<[T]>,
    ) -> Result<(), WasiNnError> {
        self.context.set_input(index, tensor_type, dimensions, data)
    }

    /// Compute the inference on the given inputs.
    pub fn compute(&mut self) -> Result<(), WasiNnError> {
        self.context.compute()
    }

    /// Compute the inference on the given inputs.
    ///
    /// Note that this method is used for the stream mode. It generates one token at a time.
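    ///
    /// # Example
    ///
    /// A minimal sketch of a streaming loop; handling of the end-of-sequence and
    /// context-full conditions is omitted and backend-specific:
    ///
    /// ```ignore
    /// graph.set_input(0, TensorType::U8, &[1], prompt.as_bytes())?;
    ///
    /// let mut token_buffer = vec![0u8; 4096];
    /// while graph.compute_single().is_ok() {
    ///     let size = graph.get_output_single(0, &mut token_buffer)?;
    ///     print!("{}", String::from_utf8_lossy(&token_buffer[..size]));
    /// }
    /// graph.finish_single()?;
    /// ```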
    pub fn compute_single(&mut self) -> Result<(), WasiNnError> {
        self.context.compute_single()
    }

    /// Copy the output tensor to `out_buffer`, returning the output's **size in bytes**.
    pub fn get_output<T: Sized>(
        &self,
        index: usize,
        out_buffer: &mut [T],
    ) -> Result<usize, WasiNnError> {
        self.context.get_output(index, out_buffer)
    }

    /// Copy the output tensor to `out_buffer`, returning the output's **size in bytes**.
    ///
    /// Note that this method is used for the stream mode. It returns one token at a time.
    pub fn get_output_single<T: Sized>(
        &self,
        index: usize,
        out_buffer: &mut [T],
    ) -> Result<usize, WasiNnError> {
        self.context.get_output_single(index, out_buffer)
    }

    /// Clear the computation context.
    ///
    /// Note that this method is used for the stream mode. It clears the context after the stream mode is finished.
    pub fn finish_single(&mut self) -> Result<(), WasiNnError> {
        self.context.fini_single()
    }
}
impl<M: BaseMetadata + serde::Serialize + Clone + Default> Drop for Graph<M> {
    fn drop(&mut self) {
        // unload the wasi-nn graph
        if let Err(e) = self.graph.unload() {
            let err_msg = format!("Failed to unload the wasi-nn graph. Reason: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            #[cfg(not(feature = "logging"))]
            eprintln!("{}", err_msg);
        }
    }
}

/// Engine type of the computation graph.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum EngineType {
    /// The GGML backend, used for LLM inference.
    Ggml,
    /// The Whisper backend, used for audio transcription.
    Whisper,
    /// The Piper backend, used for text-to-speech.
    Piper,
}