1use crate::{error::LlamaCoreError, utils::set_tensor_data_u8, BaseMetadata};
4use wasmedge_wasi_nn::{
5 Error as WasiNnError, Graph as WasiNnGraph, GraphExecutionContext, TensorType,
6};
7
/// Builder for constructing a [`Graph`] backed by a wasi-nn graph.
///
/// Wraps a `wasmedge_wasi_nn::GraphBuilder` and optionally carries the model
/// metadata that will be serialized into the graph config and stored on the
/// resulting [`Graph`].
#[derive(Debug)]
pub struct GraphBuilder<M: BaseMetadata + serde::Serialize + Clone + Default> {
    // Metadata supplied via `with_config`; `None` until then.
    metadata: Option<M>,
    // Underlying wasi-nn builder this type delegates to.
    wasi_nn_graph_builder: wasmedge_wasi_nn::GraphBuilder,
}
14impl<M: BaseMetadata + serde::Serialize + Clone + Default> GraphBuilder<M> {
15 pub fn new(ty: EngineType) -> Result<Self, LlamaCoreError> {
17 let encoding = match ty {
18 EngineType::Ggml => wasmedge_wasi_nn::GraphEncoding::Ggml,
19 EngineType::Whisper => wasmedge_wasi_nn::GraphEncoding::Whisper,
20 EngineType::Piper => wasmedge_wasi_nn::GraphEncoding::Piper,
21 };
22
23 let wasi_nn_graph_builder =
24 wasmedge_wasi_nn::GraphBuilder::new(encoding, wasmedge_wasi_nn::ExecutionTarget::AUTO);
25
26 Ok(Self {
27 metadata: None,
28 wasi_nn_graph_builder,
29 })
30 }
31
32 pub fn with_config(mut self, metadata: M) -> Result<Self, LlamaCoreError> {
33 let config = serde_json::to_string(&metadata).map_err(|e| {
34 let err_msg = e.to_string();
35
36 #[cfg(feature = "logging")]
37 error!(target: "stdout", "{}", &err_msg);
38
39 LlamaCoreError::Operation(err_msg)
40 })?;
41 self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.config(config);
42 self.metadata = Some(metadata.clone());
43
44 Ok(self)
45 }
46
47 pub fn use_cpu(mut self) -> Self {
48 self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.cpu();
49 self
50 }
51
52 pub fn use_gpu(mut self) -> Self {
53 self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.gpu();
54 self
55 }
56
57 pub fn use_tpu(mut self) -> Self {
58 self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.tpu();
59 self
60 }
61
62 pub fn build_from_buffer<B>(
63 self,
64 bytes_array: impl AsRef<[B]>,
65 ) -> Result<Graph<M>, LlamaCoreError>
66 where
67 B: AsRef<[u8]>,
68 {
69 let graph = self
71 .wasi_nn_graph_builder
72 .build_from_bytes(bytes_array)
73 .map_err(|e| {
74 let err_msg = e.to_string();
75
76 #[cfg(feature = "logging")]
77 error!(target: "stdout", "{}", &err_msg);
78
79 LlamaCoreError::Operation(err_msg)
80 })?;
81
82 let context = graph.init_execution_context().map_err(|e| {
84 let err_msg = e.to_string();
85
86 #[cfg(feature = "logging")]
87 error!(target: "stdout", "{}", &err_msg);
88
89 LlamaCoreError::Operation(err_msg)
90 })?;
91
92 let created = std::time::SystemTime::now()
93 .duration_since(std::time::UNIX_EPOCH)
94 .map_err(|e| {
95 let err_msg = e.to_string();
96
97 #[cfg(feature = "logging")]
98 error!(target: "stdout", "{}", &err_msg);
99
100 LlamaCoreError::Operation(err_msg)
101 })?;
102
103 Ok(Graph {
104 created,
105 metadata: self.metadata.clone().unwrap_or_default(),
106 graph,
107 context,
108 })
109 }
110
111 pub fn build_from_files<P>(self, files: impl AsRef<[P]>) -> Result<Graph<M>, LlamaCoreError>
112 where
113 P: AsRef<std::path::Path>,
114 {
115 let graph = self
117 .wasi_nn_graph_builder
118 .build_from_files(files)
119 .map_err(|e| {
120 let err_msg = e.to_string();
121
122 #[cfg(feature = "logging")]
123 error!(target: "stdout", "{}", &err_msg);
124
125 LlamaCoreError::Operation(err_msg)
126 })?;
127
128 let context = graph.init_execution_context().map_err(|e| {
130 let err_msg = e.to_string();
131
132 #[cfg(feature = "logging")]
133 error!(target: "stdout", "{}", &err_msg);
134
135 LlamaCoreError::Operation(err_msg)
136 })?;
137
138 let created = std::time::SystemTime::now()
139 .duration_since(std::time::UNIX_EPOCH)
140 .map_err(|e| {
141 let err_msg = e.to_string();
142
143 #[cfg(feature = "logging")]
144 error!(target: "stdout", "{}", &err_msg);
145
146 LlamaCoreError::Operation(err_msg)
147 })?;
148
149 Ok(Graph {
150 created,
151 metadata: self.metadata.clone().unwrap_or_default(),
152 graph,
153 context,
154 })
155 }
156
157 pub fn build_from_cache(self) -> Result<Graph<M>, LlamaCoreError> {
158 match &self.metadata {
159 Some(metadata) => {
160 let graph = self
162 .wasi_nn_graph_builder
163 .build_from_cache(metadata.model_alias())
164 .map_err(|e| {
165 let err_msg = e.to_string();
166
167 #[cfg(feature = "logging")]
168 error!(target: "stdout", "{}", &err_msg);
169
170 LlamaCoreError::Operation(err_msg)
171 })?;
172
173 let context = graph.init_execution_context().map_err(|e| {
175 let err_msg = e.to_string();
176
177 #[cfg(feature = "logging")]
178 error!(target: "stdout", "{}", &err_msg);
179
180 LlamaCoreError::Operation(err_msg)
181 })?;
182
183 let created = std::time::SystemTime::now()
184 .duration_since(std::time::UNIX_EPOCH)
185 .map_err(|e| {
186 let err_msg = e.to_string();
187
188 #[cfg(feature = "logging")]
189 error!(target: "stdout", "{}", &err_msg);
190
191 LlamaCoreError::Operation(err_msg)
192 })?;
193
194 Ok(Graph {
195 created,
196 metadata: metadata.clone(),
197 graph,
198 context,
199 })
200 }
201 None => {
202 let err_msg =
203 "Failed to create a Graph from cache. Reason: Metadata is not provided."
204 .to_string();
205
206 #[cfg(feature = "logging")]
207 error!(target: "stdout", "{}", &err_msg);
208
209 Err(LlamaCoreError::Operation(err_msg))
210 }
211 }
212 }
213}
214
/// A loaded wasi-nn graph together with its execution context and the
/// metadata it was configured with.
#[derive(Debug)]
pub struct Graph<M: BaseMetadata + serde::Serialize + Clone + Default> {
    // Creation time as a duration since the UNIX epoch.
    pub created: std::time::Duration,
    // Metadata the graph was built with (or `M::default()`).
    pub metadata: M,
    // Underlying wasi-nn graph handle; unloaded on drop.
    graph: WasiNnGraph,
    // Execution context used by `set_input`/`compute`/`get_output`.
    context: GraphExecutionContext,
}
223impl<M: BaseMetadata + serde::Serialize + Clone + Default> Graph<M> {
224 pub fn new(metadata: M) -> Result<Self, LlamaCoreError> {
226 let config = serde_json::to_string(&metadata).map_err(|e| {
227 let err_msg = e.to_string();
228
229 #[cfg(feature = "logging")]
230 error!(target: "stdout", "{}", &err_msg);
231
232 LlamaCoreError::Operation(err_msg)
233 })?;
234
235 let graph = wasmedge_wasi_nn::GraphBuilder::new(
237 wasmedge_wasi_nn::GraphEncoding::Ggml,
238 wasmedge_wasi_nn::ExecutionTarget::AUTO,
239 )
240 .config(config)
241 .build_from_cache(metadata.model_alias())
242 .map_err(|e| {
243 let err_msg = e.to_string();
244
245 #[cfg(feature = "logging")]
246 error!(target: "stdout", "{}", &err_msg);
247
248 LlamaCoreError::Operation(err_msg)
249 })?;
250
251 let context = graph.init_execution_context().map_err(|e| {
253 let err_msg = e.to_string();
254
255 #[cfg(feature = "logging")]
256 error!(target: "stdout", "{}", &err_msg);
257
258 LlamaCoreError::Operation(err_msg)
259 })?;
260
261 let created = std::time::SystemTime::now()
262 .duration_since(std::time::UNIX_EPOCH)
263 .map_err(|e| {
264 let err_msg = e.to_string();
265
266 #[cfg(feature = "logging")]
267 error!(target: "stdout", "{}", &err_msg);
268
269 LlamaCoreError::Operation(err_msg)
270 })?;
271
272 Ok(Self {
273 created,
274 metadata: metadata.clone(),
275 graph,
276 context,
277 })
278 }
279
280 pub fn name(&self) -> &str {
282 self.metadata.model_name()
283 }
284
285 pub fn alias(&self) -> &str {
287 self.metadata.model_alias()
288 }
289
290 pub fn update_metadata(&mut self) -> Result<(), LlamaCoreError> {
292 #[cfg(feature = "logging")]
293 info!(target: "stdout", "Update metadata for the model named {}", self.name());
294
295 let config = match serde_json::to_string(&self.metadata) {
297 Ok(config) => config,
298 Err(e) => {
299 let err_msg = format!("Failed to update metadta. Reason: Fail to serialize metadata to a JSON string. {}", e);
300
301 #[cfg(feature = "logging")]
302 error!(target: "stdout", "{}", &err_msg);
303
304 return Err(LlamaCoreError::Operation(err_msg));
305 }
306 };
307
308 let res = set_tensor_data_u8(self, 1, config.as_bytes());
309
310 #[cfg(feature = "logging")]
311 info!(target: "stdout", "Metadata updated successfully.");
312
313 res
314 }
315
316 pub fn set_input<T: Sized>(
318 &mut self,
319 index: usize,
320 tensor_type: TensorType,
321 dimensions: &[usize],
322 data: impl AsRef<[T]>,
323 ) -> Result<(), WasiNnError> {
324 self.context.set_input(index, tensor_type, dimensions, data)
325 }
326
327 pub fn compute(&mut self) -> Result<(), WasiNnError> {
329 self.context.compute()
330 }
331
332 pub fn compute_single(&mut self) -> Result<(), WasiNnError> {
336 self.context.compute_single()
337 }
338
339 pub fn get_output<T: Sized>(
341 &self,
342 index: usize,
343 out_buffer: &mut [T],
344 ) -> Result<usize, WasiNnError> {
345 self.context.get_output(index, out_buffer)
346 }
347
348 pub fn get_output_single<T: Sized>(
352 &self,
353 index: usize,
354 out_buffer: &mut [T],
355 ) -> Result<usize, WasiNnError> {
356 self.context.get_output_single(index, out_buffer)
357 }
358
359 pub fn finish_single(&mut self) -> Result<(), WasiNnError> {
363 self.context.fini_single()
364 }
365}
366impl<M: BaseMetadata + serde::Serialize + Clone + Default> Drop for Graph<M> {
367 fn drop(&mut self) {
368 if let Err(e) = self.graph.unload() {
370 let err_msg = format!("Failed to unload the wasi-nn graph. Reason: {}", e);
371
372 #[cfg(feature = "logging")]
373 error!(target: "stdout", "{}", err_msg);
374
375 #[cfg(not(feature = "logging"))]
376 eprintln!("{}", err_msg);
377 }
378 }
379}
380
/// The inference engine backing a [`GraphBuilder`]; maps one-to-one onto a
/// `wasmedge_wasi_nn::GraphEncoding`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum EngineType {
    /// GGML (llama.cpp-family) text models.
    Ggml,
    /// Whisper speech-to-text models.
    Whisper,
    /// Piper text-to-speech models.
    Piper,
}