1use crate::{error::LlamaCoreError, utils::set_tensor_data_u8, BaseMetadata};
4use wasmedge_wasi_nn::{
5 Error as WasiNnError, Graph as WasiNnGraph, GraphExecutionContext, TensorType,
6};
7
8#[derive(Debug)]
10pub struct GraphBuilder<M: BaseMetadata + serde::Serialize + Clone + Default> {
11 metadata: Option<M>,
12 wasi_nn_graph_builder: wasmedge_wasi_nn::GraphBuilder,
13}
14impl<M: BaseMetadata + serde::Serialize + Clone + Default> GraphBuilder<M> {
15 pub fn new(ty: EngineType) -> Result<Self, LlamaCoreError> {
17 let encoding = match ty {
18 EngineType::Ggml => wasmedge_wasi_nn::GraphEncoding::Ggml,
19 EngineType::Whisper => wasmedge_wasi_nn::GraphEncoding::Whisper,
20 EngineType::Piper => wasmedge_wasi_nn::GraphEncoding::Piper,
21 };
22
23 let wasi_nn_graph_builder =
24 wasmedge_wasi_nn::GraphBuilder::new(encoding, wasmedge_wasi_nn::ExecutionTarget::AUTO);
25
26 Ok(Self {
27 metadata: None,
28 wasi_nn_graph_builder,
29 })
30 }
31
32 pub fn with_config(mut self, metadata: M) -> Result<Self, LlamaCoreError> {
33 let config = serde_json::to_string(&metadata).map_err(|e| {
34 let err_msg = e.to_string();
35
36 #[cfg(feature = "logging")]
37 error!(target: "stdout", "{}", &err_msg);
38
39 LlamaCoreError::Operation(err_msg)
40 })?;
41 self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.config(config);
42 self.metadata = Some(metadata.clone());
43
44 Ok(self)
45 }
46
47 pub fn use_cpu(mut self) -> Self {
48 self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.cpu();
49 self
50 }
51
52 pub fn use_gpu(mut self) -> Self {
53 self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.gpu();
54 self
55 }
56
57 pub fn use_tpu(mut self) -> Self {
58 self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.tpu();
59 self
60 }
61
62 pub fn build_from_buffer<B>(
63 self,
64 bytes_array: impl AsRef<[B]>,
65 ) -> Result<Graph<M>, LlamaCoreError>
66 where
67 B: AsRef<[u8]>,
68 {
69 let graph = self
71 .wasi_nn_graph_builder
72 .build_from_bytes(bytes_array)
73 .map_err(|e| {
74 let err_msg = e.to_string();
75
76 #[cfg(feature = "logging")]
77 error!(target: "stdout", "{}", &err_msg);
78
79 LlamaCoreError::Operation(err_msg)
80 })?;
81
82 let context = graph.init_execution_context().map_err(|e| {
84 let err_msg = e.to_string();
85
86 #[cfg(feature = "logging")]
87 error!(target: "stdout", "{}", &err_msg);
88
89 LlamaCoreError::Operation(err_msg)
90 })?;
91
92 let created = std::time::SystemTime::now();
93
94 Ok(Graph {
95 created,
96 metadata: self.metadata.clone().unwrap_or_default(),
97 graph,
98 context,
99 })
100 }
101
102 pub fn build_from_files<P>(self, files: impl AsRef<[P]>) -> Result<Graph<M>, LlamaCoreError>
103 where
104 P: AsRef<std::path::Path>,
105 {
106 let graph = self
108 .wasi_nn_graph_builder
109 .build_from_files(files)
110 .map_err(|e| {
111 let err_msg = e.to_string();
112
113 #[cfg(feature = "logging")]
114 error!(target: "stdout", "{}", &err_msg);
115
116 LlamaCoreError::Operation(err_msg)
117 })?;
118
119 let context = graph.init_execution_context().map_err(|e| {
121 let err_msg = e.to_string();
122
123 #[cfg(feature = "logging")]
124 error!(target: "stdout", "{}", &err_msg);
125
126 LlamaCoreError::Operation(err_msg)
127 })?;
128
129 let created = std::time::SystemTime::now();
130
131 Ok(Graph {
132 created,
133 metadata: self.metadata.clone().unwrap_or_default(),
134 graph,
135 context,
136 })
137 }
138
139 pub fn build_from_cache(self) -> Result<Graph<M>, LlamaCoreError> {
140 match &self.metadata {
141 Some(metadata) => {
142 let graph = self
144 .wasi_nn_graph_builder
145 .build_from_cache(metadata.model_alias())
146 .map_err(|e| {
147 let err_msg = e.to_string();
148
149 #[cfg(feature = "logging")]
150 error!(target: "stdout", "{}", &err_msg);
151
152 LlamaCoreError::Operation(err_msg)
153 })?;
154
155 let context = graph.init_execution_context().map_err(|e| {
157 let err_msg = e.to_string();
158
159 #[cfg(feature = "logging")]
160 error!(target: "stdout", "{}", &err_msg);
161
162 LlamaCoreError::Operation(err_msg)
163 })?;
164
165 let created = std::time::SystemTime::now();
166
167 Ok(Graph {
168 created,
169 metadata: metadata.clone(),
170 graph,
171 context,
172 })
173 }
174 None => {
175 let err_msg =
176 "Failed to create a Graph from cache. Reason: Metadata is not provided."
177 .to_string();
178
179 #[cfg(feature = "logging")]
180 error!(target: "stdout", "{}", &err_msg);
181
182 Err(LlamaCoreError::Operation(err_msg))
183 }
184 }
185 }
186}
187
188#[derive(Debug)]
190pub struct Graph<M: BaseMetadata + serde::Serialize + Clone + Default> {
191 pub created: std::time::SystemTime,
192 pub metadata: M,
193 graph: WasiNnGraph,
194 context: GraphExecutionContext,
195}
196impl<M: BaseMetadata + serde::Serialize + Clone + Default> Graph<M> {
197 pub fn new(metadata: M) -> Result<Self, LlamaCoreError> {
199 let config = serde_json::to_string(&metadata).map_err(|e| {
200 let err_msg = e.to_string();
201
202 #[cfg(feature = "logging")]
203 error!(target: "stdout", "{}", &err_msg);
204
205 LlamaCoreError::Operation(err_msg)
206 })?;
207
208 let graph = wasmedge_wasi_nn::GraphBuilder::new(
210 wasmedge_wasi_nn::GraphEncoding::Ggml,
211 wasmedge_wasi_nn::ExecutionTarget::AUTO,
212 )
213 .config(config)
214 .build_from_cache(metadata.model_alias())
215 .map_err(|e| {
216 let err_msg = e.to_string();
217
218 #[cfg(feature = "logging")]
219 error!(target: "stdout", "{}", &err_msg);
220
221 LlamaCoreError::Operation(err_msg)
222 })?;
223
224 let context = graph.init_execution_context().map_err(|e| {
226 let err_msg = e.to_string();
227
228 #[cfg(feature = "logging")]
229 error!(target: "stdout", "{}", &err_msg);
230
231 LlamaCoreError::Operation(err_msg)
232 })?;
233
234 let created = std::time::SystemTime::now();
235
236 Ok(Self {
237 created,
238 metadata: metadata.clone(),
239 graph,
240 context,
241 })
242 }
243
244 pub fn name(&self) -> &str {
246 self.metadata.model_name()
247 }
248
249 pub fn alias(&self) -> &str {
251 self.metadata.model_alias()
252 }
253
254 pub fn update_metadata(&mut self) -> Result<(), LlamaCoreError> {
256 #[cfg(feature = "logging")]
257 info!(target: "stdout", "Update metadata for the model named {}", self.name());
258
259 let config = match serde_json::to_string(&self.metadata) {
261 Ok(config) => config,
262 Err(e) => {
263 let err_msg = format!("Failed to update metadta. Reason: Fail to serialize metadata to a JSON string. {e}");
264
265 #[cfg(feature = "logging")]
266 error!(target: "stdout", "{}", &err_msg);
267
268 return Err(LlamaCoreError::Operation(err_msg));
269 }
270 };
271
272 let res = set_tensor_data_u8(self, 1, config.as_bytes());
273
274 #[cfg(feature = "logging")]
275 info!(target: "stdout", "Metadata updated successfully.");
276
277 res
278 }
279
280 pub fn set_input<T: Sized>(
282 &mut self,
283 index: usize,
284 tensor_type: TensorType,
285 dimensions: &[usize],
286 data: impl AsRef<[T]>,
287 ) -> Result<(), WasiNnError> {
288 self.context.set_input(index, tensor_type, dimensions, data)
289 }
290
291 pub fn compute(&mut self) -> Result<(), WasiNnError> {
293 self.context.compute()
294 }
295
296 pub fn compute_single(&mut self) -> Result<(), WasiNnError> {
300 self.context.compute_single()
301 }
302
303 pub fn get_output<T: Sized>(
305 &self,
306 index: usize,
307 out_buffer: &mut [T],
308 ) -> Result<usize, WasiNnError> {
309 self.context.get_output(index, out_buffer)
310 }
311
312 pub fn get_output_single<T: Sized>(
316 &self,
317 index: usize,
318 out_buffer: &mut [T],
319 ) -> Result<usize, WasiNnError> {
320 self.context.get_output_single(index, out_buffer)
321 }
322
323 pub fn finish_single(&mut self) -> Result<(), WasiNnError> {
327 self.context.fini_single()
328 }
329}
330impl<M: BaseMetadata + serde::Serialize + Clone + Default> Drop for Graph<M> {
331 fn drop(&mut self) {
332 if let Err(e) = self.graph.unload() {
334 let err_msg = format!("Failed to unload the wasi-nn graph. Reason: {e}");
335
336 #[cfg(feature = "logging")]
337 error!(target: "stdout", "{err_msg}");
338
339 #[cfg(not(feature = "logging"))]
340 eprintln!("{}", err_msg);
341 }
342 }
343}
344
/// Inference engine that a graph is built for.
///
/// Each variant maps to the wasi-nn graph encoding of the same name when a
/// `GraphBuilder` is created. The enum is fieldless, so it is `Copy` and
/// `Hash` in addition to the comparison traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum EngineType {
    /// Selects the `Ggml` graph encoding.
    Ggml,
    /// Selects the `Whisper` graph encoding.
    Whisper,
    /// Selects the `Piper` graph encoding.
    Piper,
}