llama_core/graph.rs

//! Defines the [`Graph`] and [`GraphBuilder`] APIs for creating a new computation graph.

use crate::{error::LlamaCoreError, utils::set_tensor_data_u8, BaseMetadata};
#[cfg(feature = "logging")]
use log::{error, info};
use wasmedge_wasi_nn::{
    Error as WasiNnError, Graph as WasiNnGraph, GraphExecutionContext, TensorType,
};

/// Builder for creating a new computation graph.
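///
/// # Example
///
/// A minimal sketch of the builder flow, assuming a hypothetical `MyMetadata`
/// type that implements `BaseMetadata + serde::Serialize + Clone + Default`
/// and a local `model.gguf` file (both are assumptions, not part of this
/// module):
///
/// ```ignore
/// let graph: Graph<MyMetadata> = GraphBuilder::new(EngineType::Ggml)?
///     .with_config(MyMetadata::default())?
///     .use_gpu()
///     .build_from_files(["model.gguf"])?;
/// ```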
#[derive(Debug)]
pub struct GraphBuilder<M: BaseMetadata + serde::Serialize + Clone + Default> {
    metadata: Option<M>,
    wasi_nn_graph_builder: wasmedge_wasi_nn::GraphBuilder,
}
impl<M: BaseMetadata + serde::Serialize + Clone + Default> GraphBuilder<M> {
    /// Create a new computation graph builder.
    pub fn new(ty: EngineType) -> Result<Self, LlamaCoreError> {
        let encoding = match ty {
            EngineType::Ggml => wasmedge_wasi_nn::GraphEncoding::Ggml,
            EngineType::Whisper => wasmedge_wasi_nn::GraphEncoding::Whisper,
            EngineType::Piper => wasmedge_wasi_nn::GraphEncoding::Piper,
        };

        let wasi_nn_graph_builder =
            wasmedge_wasi_nn::GraphBuilder::new(encoding, wasmedge_wasi_nn::ExecutionTarget::AUTO);

        Ok(Self {
            metadata: None,
            wasi_nn_graph_builder,
        })
    }

    /// Set the metadata (model configuration) for the graph. The metadata is
    /// serialized to a JSON string and passed to the wasi-nn backend as the
    /// graph config.
    pub fn with_config(mut self, metadata: M) -> Result<Self, LlamaCoreError> {
        let config = serde_json::to_string(&metadata).map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;
        self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.config(config);
        self.metadata = Some(metadata);

        Ok(self)
    }

    /// Use the CPU as the execution target.
    pub fn use_cpu(mut self) -> Self {
        self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.cpu();
        self
    }

    /// Use the GPU as the execution target.
    pub fn use_gpu(mut self) -> Self {
        self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.gpu();
        self
    }

    /// Use the TPU as the execution target.
    pub fn use_tpu(mut self) -> Self {
        self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.tpu();
        self
    }

    /// Build the graph from the given model bytes in memory.
    pub fn build_from_buffer<B>(
        self,
        bytes_array: impl AsRef<[B]>,
    ) -> Result<Graph<M>, LlamaCoreError>
    where
        B: AsRef<[u8]>,
    {
        // load the model
        let graph = self
            .wasi_nn_graph_builder
            .build_from_bytes(bytes_array)
            .map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

        // initialize the execution context
        let context = graph.init_execution_context().map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        // record the creation time as seconds since the Unix epoch
        let created = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

        Ok(Graph {
            created,
            metadata: self.metadata.unwrap_or_default(),
            graph,
            context,
        })
    }

    /// Build the graph from the given model files on disk.
    pub fn build_from_files<P>(self, files: impl AsRef<[P]>) -> Result<Graph<M>, LlamaCoreError>
    where
        P: AsRef<std::path::Path>,
    {
        // load the model
        let graph = self
            .wasi_nn_graph_builder
            .build_from_files(files)
            .map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

        // initialize the execution context
        let context = graph.init_execution_context().map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        // record the creation time as seconds since the Unix epoch
        let created = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

        Ok(Graph {
            created,
            metadata: self.metadata.unwrap_or_default(),
            graph,
            context,
        })
    }

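    /// Build the graph from a model that was previously registered (e.g.
    /// preloaded) in the wasi-nn cache under the alias given by the metadata
    /// set via `with_config`.
    ///
    /// A minimal sketch, assuming a hypothetical `MyMetadata` type and a model
    /// registered under the alias returned by its `model_alias()`:
    ///
    /// ```ignore
    /// let graph: Graph<MyMetadata> = GraphBuilder::new(EngineType::Ggml)?
    ///     .with_config(MyMetadata::default())?
    ///     .build_from_cache()?;
    /// ```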
    pub fn build_from_cache(self) -> Result<Graph<M>, LlamaCoreError> {
        match &self.metadata {
            Some(metadata) => {
                // load the model
                let graph = self
                    .wasi_nn_graph_builder
                    .build_from_cache(metadata.model_alias())
                    .map_err(|e| {
                        let err_msg = e.to_string();

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        LlamaCoreError::Operation(err_msg)
                    })?;

                // initialize the execution context
                let context = graph.init_execution_context().map_err(|e| {
                    let err_msg = e.to_string();

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    LlamaCoreError::Operation(err_msg)
                })?;

                // record the creation time as seconds since the Unix epoch
                let created = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map_err(|e| {
                        let err_msg = e.to_string();

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        LlamaCoreError::Operation(err_msg)
                    })?;

                Ok(Graph {
                    created,
                    metadata: metadata.clone(),
                    graph,
                    context,
                })
            }
            None => {
                let err_msg =
                    "Failed to create a Graph from cache. Reason: Metadata is not provided."
                        .to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                Err(LlamaCoreError::Operation(err_msg))
            }
        }
    }
}

/// Wrapper of the `wasmedge_wasi_nn::Graph` struct.
#[derive(Debug)]
pub struct Graph<M: BaseMetadata + serde::Serialize + Clone + Default> {
    /// Creation time of the graph as seconds since the Unix epoch.
    pub created: std::time::Duration,
    /// Metadata (model configuration) of the graph.
    pub metadata: M,
    graph: WasiNnGraph,
    context: GraphExecutionContext,
}
impl<M: BaseMetadata + serde::Serialize + Clone + Default> Graph<M> {
    /// Create a new computation graph from the given metadata.
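    ///
    /// Note that this constructor always uses the GGML backend and loads the
    /// model from the wasi-nn cache under `metadata.model_alias()`.
    ///
    /// A minimal sketch, assuming a hypothetical `MyMetadata` type whose alias
    /// points at a registered model:
    ///
    /// ```ignore
    /// let graph = Graph::new(MyMetadata::default())?;
    /// ```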
    pub fn new(metadata: M) -> Result<Self, LlamaCoreError> {
        let config = serde_json::to_string(&metadata).map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        // load the model from the wasi-nn cache with the GGML backend
        let graph = wasmedge_wasi_nn::GraphBuilder::new(
            wasmedge_wasi_nn::GraphEncoding::Ggml,
            wasmedge_wasi_nn::ExecutionTarget::AUTO,
        )
        .config(config)
        .build_from_cache(metadata.model_alias())
        .map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        // initialize the execution context
        let context = graph.init_execution_context().map_err(|e| {
            let err_msg = e.to_string();

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        // record the creation time as seconds since the Unix epoch
        let created = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|e| {
                let err_msg = e.to_string();

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

        Ok(Self {
            created,
            metadata,
            graph,
            context,
        })
    }

    /// Get the name of the model.
    pub fn name(&self) -> &str {
        self.metadata.model_name()
    }

    /// Get the alias of the model.
    pub fn alias(&self) -> &str {
        self.metadata.model_alias()
    }

    /// Push the current `metadata` to the model by serializing it to JSON and
    /// writing it to the backend's metadata tensor.
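    ///
    /// A minimal sketch (`updated` is a hypothetical value of the metadata
    /// type):
    ///
    /// ```ignore
    /// graph.metadata = updated;
    /// graph.update_metadata()?;
    /// ```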
    pub fn update_metadata(&mut self) -> Result<(), LlamaCoreError> {
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Update metadata for the model named {}", self.name());

        // serialize the metadata to a JSON string
        let config = match serde_json::to_string(&self.metadata) {
            Ok(config) => config,
            Err(e) => {
                let err_msg = format!("Failed to update metadata. Reason: Failed to serialize metadata to a JSON string. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Operation(err_msg));
            }
        };

        // update the metadata by writing it to the input tensor at index 1
        set_tensor_data_u8(self, 1, config.as_bytes())?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "Metadata updated successfully.");

        Ok(())
    }

    /// Set the input tensor at the given index. The data may have any sized element type, not only [u8](https://doc.rust-lang.org/nightly/std/primitive.u8.html), but also [f32](https://doc.rust-lang.org/nightly/std/primitive.f32.html), [i32](https://doc.rust-lang.org/nightly/std/primitive.i32.html), etc.
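    ///
    /// A minimal sketch, assuming `prompt` is a `String` holding the prompt
    /// text (a hypothetical variable) and that the backend expects the prompt
    /// bytes at input index 0:
    ///
    /// ```ignore
    /// graph.set_input(0, TensorType::U8, &[1], prompt.as_bytes())?;
    /// ```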
    pub fn set_input<T: Sized>(
        &mut self,
        index: usize,
        tensor_type: TensorType,
        dimensions: &[usize],
        data: impl AsRef<[T]>,
    ) -> Result<(), WasiNnError> {
        self.context.set_input(index, tensor_type, dimensions, data)
    }

    /// Compute the inference on the given inputs.
    pub fn compute(&mut self) -> Result<(), WasiNnError> {
        self.context.compute()
    }

    /// Compute the inference on the given inputs.
    ///
    /// Note that this method is used for the stream mode. It generates one token at a time.
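    ///
    /// A minimal sketch of a stream loop, assuming the backend signals the end
    /// of the sequence with an error and that each single-token output fits in
    /// the buffer (both are assumptions):
    ///
    /// ```ignore
    /// let mut buf = vec![0u8; 4096];
    /// while graph.compute_single().is_ok() {
    ///     let n = graph.get_output_single(0, &mut buf)?;
    ///     print!("{}", String::from_utf8_lossy(&buf[..n]));
    /// }
    /// graph.finish_single()?;
    /// ```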
    pub fn compute_single(&mut self) -> Result<(), WasiNnError> {
        self.context.compute_single()
    }

    /// Copy the output tensor at the given index into `out_buffer` and return the output's **size in bytes**.
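    ///
    /// A minimal sketch, assuming the output at index 0 is UTF-8 text and the
    /// buffer is large enough (both are assumptions):
    ///
    /// ```ignore
    /// let mut buf = vec![0u8; 4096];
    /// let n = graph.get_output(0, &mut buf)?;
    /// let text = String::from_utf8_lossy(&buf[..n]);
    /// ```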
    pub fn get_output<T: Sized>(
        &self,
        index: usize,
        out_buffer: &mut [T],
    ) -> Result<usize, WasiNnError> {
        self.context.get_output(index, out_buffer)
    }

    /// Copy the output tensor at the given index into `out_buffer` and return the output's **size in bytes**.
    ///
    /// Note that this method is used for the stream mode. It returns one token at a time.
    pub fn get_output_single<T: Sized>(
        &self,
        index: usize,
        out_buffer: &mut [T],
    ) -> Result<usize, WasiNnError> {
        self.context.get_output_single(index, out_buffer)
    }

    /// Clear the computation context.
    ///
    /// Note that this method is used for the stream mode. It clears the context after the stream mode is finished.
    pub fn finish_single(&mut self) -> Result<(), WasiNnError> {
        self.context.fini_single()
    }
}
impl<M: BaseMetadata + serde::Serialize + Clone + Default> Drop for Graph<M> {
    fn drop(&mut self) {
        // unload the wasi-nn graph
        if let Err(e) = self.graph.unload() {
            let err_msg = format!("Failed to unload the wasi-nn graph. Reason: {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            #[cfg(not(feature = "logging"))]
            eprintln!("{}", err_msg);
        }
    }
}

/// Engine type
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum EngineType {
    /// The GGML backend, used for LLM inference.
    Ggml,
    /// The Whisper backend, used for audio transcription.
    Whisper,
    /// The Piper backend, used for text-to-speech.
    Piper,
}