1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
4
5#[cfg(feature = "logging")]
6#[macro_use]
7extern crate log;
8
9pub mod audio;
10pub mod chat;
11pub mod completions;
12pub mod embeddings;
13pub mod error;
14pub mod files;
15pub mod graph;
16pub mod images;
17pub mod metadata;
18pub mod models;
19pub mod tts;
20pub mod utils;
21
22pub use error::LlamaCoreError;
23pub use graph::{EngineType, Graph, GraphBuilder};
24#[cfg(feature = "whisper")]
25use metadata::whisper::WhisperMetadata;
26pub use metadata::{
27 ggml::{GgmlMetadata, GgmlTtsMetadata},
28 piper::PiperMetadata,
29 BaseMetadata,
30};
31use once_cell::sync::OnceCell;
32use std::{
33 collections::HashMap,
34 path::Path,
35 sync::{Mutex, RwLock},
36};
37use utils::{get_output_buffer, RunningMode};
38use wasmedge_stable_diffusion::*;
39
/// Chat-completion graphs, keyed by model name. Populated once by
/// `init_ggml_chat_context`.
pub(crate) static CHAT_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
/// Embedding graphs, keyed by model name. Populated once by
/// `init_ggml_embeddings_context`.
pub(crate) static EMBEDDING_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
/// Text-to-speech graphs, keyed by model name. Populated once by
/// `init_ggml_tts_context`.
pub(crate) static TTS_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlTtsMetadata>>>> =
    OnceCell::new();
// NOTE(review): presumably a scratch buffer for partially decoded UTF-8 byte
// sequences — usage is not visible in this chunk; confirm against `utils`.
pub(crate) static CACHED_UTF8_ENCODINGS: OnceCell<Mutex<Vec<u8>>> = OnceCell::new();
/// The set of currently active running modes. Modes are bitflags and are
/// OR-ed together as contexts are initialized (see the `init_ggml_*` functions).
pub(crate) static RUNNING_MODE: OnceCell<RwLock<RunningMode>> = OnceCell::new();
/// Stable-diffusion text-to-image context. Populated once by the
/// `init_sd_context_*` functions.
pub(crate) static SD_TEXT_TO_IMAGE: OnceCell<Mutex<TextToImage>> = OnceCell::new();
/// Stable-diffusion image-to-image context. Populated once by the
/// `init_sd_context_*` functions.
pub(crate) static SD_IMAGE_TO_IMAGE: OnceCell<Mutex<ImageToImage>> = OnceCell::new();
/// Whisper audio-transcription graph. Set (or re-set) by `init_whisper_context`.
#[cfg(feature = "whisper")]
pub(crate) static AUDIO_GRAPH: OnceCell<Mutex<Graph<WhisperMetadata>>> = OnceCell::new();
/// Piper text-to-speech graph. Populated once by `init_piper_context`.
pub(crate) static PIPER_GRAPH: OnceCell<Mutex<Graph<PiperMetadata>>> = OnceCell::new();

/// Maximum output buffer size in bytes: 2^14 * 15 + 128 = 245,888.
pub(crate) const MAX_BUFFER_SIZE: usize = 2usize.pow(14) * 15 + 128;
// NOTE(review): presumably the index of the graph's primary output tensor —
// not used in this chunk; confirm against `utils::get_output_buffer` callers.
pub(crate) const OUTPUT_TENSOR: usize = 0;
/// Tensor index passed to `get_output_buffer` to read the plugin's version
/// metadata (see `get_plugin_info_by_graph`).
const PLUGIN_VERSION: usize = 1;

/// Name of the directory used for storing archives.
pub const ARCHIVES_DIR: &str = "archives";
69
70pub fn init_ggml_chat_context(metadata_for_chats: &[GgmlMetadata]) -> Result<(), LlamaCoreError> {
72 #[cfg(feature = "logging")]
73 info!(target: "stdout", "Initializing the core context");
74
75 if metadata_for_chats.is_empty() {
76 let err_msg = "The metadata for chat models is empty";
77
78 #[cfg(feature = "logging")]
79 error!(target: "stdout", "{err_msg}");
80
81 return Err(LlamaCoreError::InitContext(err_msg.into()));
82 }
83
84 let mut chat_graphs = HashMap::new();
85 for metadata in metadata_for_chats {
86 let graph = Graph::new(metadata.clone())?;
87
88 chat_graphs.insert(graph.name().to_string(), graph);
89 }
90 CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
91 let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";
92
93 #[cfg(feature = "logging")]
94 error!(target: "stdout", "{err_msg}");
95
96 LlamaCoreError::InitContext(err_msg.into())
97 })?;
98
99 let running_mode = RunningMode::CHAT;
101 match RUNNING_MODE.get() {
102 Some(mode) => {
103 let mut mode = mode.write().unwrap();
104 *mode |= running_mode;
105 }
106 None => {
107 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
108 let err_msg = "Failed to initialize the chat context. Reason: The `RUNNING_MODE` has already been initialized";
109
110 #[cfg(feature = "logging")]
111 error!(target: "stdout", "{err_msg}");
112
113 LlamaCoreError::InitContext(err_msg.into())
114 })?;
115 }
116 }
117
118 Ok(())
119}
120
121pub fn init_ggml_embeddings_context(
123 metadata_for_embeddings: &[GgmlMetadata],
124) -> Result<(), LlamaCoreError> {
125 #[cfg(feature = "logging")]
126 info!(target: "stdout", "Initializing the embeddings context");
127
128 if metadata_for_embeddings.is_empty() {
129 let err_msg = "The metadata for chat models is empty";
130
131 #[cfg(feature = "logging")]
132 error!(target: "stdout", "{err_msg}");
133
134 return Err(LlamaCoreError::InitContext(err_msg.into()));
135 }
136
137 let mut embedding_graphs = HashMap::new();
138 for metadata in metadata_for_embeddings {
139 let graph = Graph::new(metadata.clone())?;
140
141 embedding_graphs.insert(graph.name().to_string(), graph);
142 }
143 EMBEDDING_GRAPHS
144 .set(Mutex::new(embedding_graphs))
145 .map_err(|_| {
146 let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";
147
148 #[cfg(feature = "logging")]
149 error!(target: "stdout", "{err_msg}");
150
151 LlamaCoreError::InitContext(err_msg.into())
152 })?;
153
154 let running_mode = RunningMode::EMBEDDINGS;
156 match RUNNING_MODE.get() {
157 Some(mode) => {
158 let mut mode = mode.write().unwrap();
159 *mode |= running_mode;
160 }
161 None => {
162 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
163 let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
164
165 #[cfg(feature = "logging")]
166 error!(target: "stdout", "{err_msg}");
167
168 LlamaCoreError::InitContext(err_msg.into())
169 })?;
170 }
171 }
172
173 Ok(())
174}
175
176pub fn init_ggml_tts_context(metadata_for_tts: &[GgmlTtsMetadata]) -> Result<(), LlamaCoreError> {
178 #[cfg(feature = "logging")]
179 info!(target: "stdout", "Initializing the TTS context");
180
181 if metadata_for_tts.is_empty() {
182 let err_msg = "The metadata for tts models is empty";
183
184 #[cfg(feature = "logging")]
185 error!(target: "stdout", "{err_msg}");
186
187 return Err(LlamaCoreError::InitContext(err_msg.into()));
188 }
189
190 let mut tts_graphs = HashMap::new();
191 for metadata in metadata_for_tts {
192 let graph = Graph::new(metadata.clone())?;
193
194 tts_graphs.insert(graph.name().to_string(), graph);
195 }
196 TTS_GRAPHS.set(Mutex::new(tts_graphs)).map_err(|_| {
197 let err_msg = "Failed to initialize the core context. Reason: The `TTS_GRAPHS` has already been initialized";
198
199 #[cfg(feature = "logging")]
200 error!(target: "stdout", "{err_msg}");
201
202 LlamaCoreError::InitContext(err_msg.into())
203 })?;
204
205 let running_mode = RunningMode::TTS;
207 match RUNNING_MODE.get() {
208 Some(mode) => {
209 let mut mode = mode.write().unwrap();
210 *mode |= running_mode;
211 }
212 None => {
213 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
214 let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
215
216 #[cfg(feature = "logging")]
217 error!(target: "stdout", "{err_msg}");
218
219 LlamaCoreError::InitContext(err_msg.into())
220 })?;
221 }
222 }
223
224 Ok(())
225}
226
227pub fn get_plugin_info() -> Result<PluginInfo, LlamaCoreError> {
231 #[cfg(feature = "logging")]
232 debug!(target: "stdout", "Getting the plugin info");
233
234 let running_mode = running_mode()?;
235
236 if running_mode.contains(RunningMode::CHAT) || running_mode.contains(RunningMode::RAG) {
237 let chat_graphs = match CHAT_GRAPHS.get() {
238 Some(chat_graphs) => chat_graphs,
239 None => {
240 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
241
242 #[cfg(feature = "logging")]
243 error!(target: "stdout", "{err_msg}");
244
245 return Err(LlamaCoreError::Operation(err_msg.into()));
246 }
247 };
248
249 let chat_graphs = chat_graphs.lock().map_err(|e| {
250 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
251
252 #[cfg(feature = "logging")]
253 error!(target: "stdout", "{}", &err_msg);
254
255 LlamaCoreError::Operation(err_msg)
256 })?;
257
258 let graph = match chat_graphs.values().next() {
259 Some(graph) => graph,
260 None => {
261 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
262
263 #[cfg(feature = "logging")]
264 error!(target: "stdout", "{err_msg}");
265
266 return Err(LlamaCoreError::Operation(err_msg.into()));
267 }
268 };
269
270 get_plugin_info_by_graph(graph)
271 } else if running_mode.contains(RunningMode::EMBEDDINGS) {
272 let embedding_graphs = match EMBEDDING_GRAPHS.get() {
273 Some(embedding_graphs) => embedding_graphs,
274 None => {
275 let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
276
277 #[cfg(feature = "logging")]
278 error!(target: "stdout", "{err_msg}");
279
280 return Err(LlamaCoreError::Operation(err_msg.into()));
281 }
282 };
283
284 let embedding_graphs = embedding_graphs.lock().map_err(|e| {
285 let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {e}");
286
287 #[cfg(feature = "logging")]
288 error!(target: "stdout", "{}", &err_msg);
289
290 LlamaCoreError::Operation(err_msg)
291 })?;
292
293 let graph = match embedding_graphs.values().next() {
294 Some(graph) => graph,
295 None => {
296 let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
297
298 #[cfg(feature = "logging")]
299 error!(target: "stdout", "{err_msg}");
300
301 return Err(LlamaCoreError::Operation(err_msg.into()));
302 }
303 };
304
305 get_plugin_info_by_graph(graph)
306 } else if running_mode.contains(RunningMode::TTS) {
307 let tts_graphs = match TTS_GRAPHS.get() {
308 Some(tts_graphs) => tts_graphs,
309 None => {
310 let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
311
312 #[cfg(feature = "logging")]
313 error!(target: "stdout", "{err_msg}");
314
315 return Err(LlamaCoreError::Operation(err_msg.into()));
316 }
317 };
318
319 let tts_graphs = tts_graphs.lock().map_err(|e| {
320 let err_msg = format!("Fail to acquire the lock of `TTS_GRAPHS`. {e}");
321
322 #[cfg(feature = "logging")]
323 error!(target: "stdout", "{}", &err_msg);
324
325 LlamaCoreError::Operation(err_msg)
326 })?;
327
328 let graph = match tts_graphs.values().next() {
329 Some(graph) => graph,
330 None => {
331 let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
332
333 #[cfg(feature = "logging")]
334 error!(target: "stdout", "{err_msg}");
335
336 return Err(LlamaCoreError::Operation(err_msg.into()));
337 }
338 };
339
340 get_plugin_info_by_graph(graph)
341 } else {
342 let err_msg = "RUNNING_MODE is not set";
343
344 #[cfg(feature = "logging")]
345 error!(target: "stdout", "{err_msg}");
346
347 Err(LlamaCoreError::Operation(err_msg.into()))
348 }
349}
350
351fn get_plugin_info_by_graph<M: BaseMetadata + serde::Serialize + Clone + Default>(
352 graph: &Graph<M>,
353) -> Result<PluginInfo, LlamaCoreError> {
354 #[cfg(feature = "logging")]
355 debug!(target: "stdout", "Getting the plugin info by the graph named {}", graph.name());
356
357 let output_buffer = get_output_buffer(graph, PLUGIN_VERSION)?;
359 let metadata: serde_json::Value = serde_json::from_slice(&output_buffer[..]).map_err(|e| {
360 let err_msg = format!("Fail to deserialize the plugin metadata. {e}");
361
362 #[cfg(feature = "logging")]
363 error!(target: "stdout", "{}", &err_msg);
364
365 LlamaCoreError::Operation(err_msg)
366 })?;
367
368 let plugin_build_number = match metadata.get("llama_build_number") {
370 Some(value) => match value.as_u64() {
371 Some(number) => number,
372 None => {
373 let err_msg = "Failed to convert the build number of the plugin to u64";
374
375 #[cfg(feature = "logging")]
376 error!(target: "stdout", "{err_msg}");
377
378 return Err(LlamaCoreError::Operation(err_msg.into()));
379 }
380 },
381 None => {
382 let err_msg = "Metadata does not have the field `llama_build_number`.";
383
384 #[cfg(feature = "logging")]
385 error!(target: "stdout", "{err_msg}");
386
387 return Err(LlamaCoreError::Operation(err_msg.into()));
388 }
389 };
390
391 let plugin_commit = match metadata.get("llama_commit") {
393 Some(value) => match value.as_str() {
394 Some(commit) => commit,
395 None => {
396 let err_msg = "Failed to convert the commit id of the plugin to string";
397
398 #[cfg(feature = "logging")]
399 error!(target: "stdout", "{err_msg}");
400
401 return Err(LlamaCoreError::Operation(err_msg.into()));
402 }
403 },
404 None => {
405 let err_msg = "Metadata does not have the field `llama_commit`.";
406
407 #[cfg(feature = "logging")]
408 error!(target: "stdout", "{err_msg}");
409
410 return Err(LlamaCoreError::Operation(err_msg.into()));
411 }
412 };
413
414 #[cfg(feature = "logging")]
415 debug!(target: "stdout", "Plugin info: b{plugin_build_number}(commit {plugin_commit})");
416
417 Ok(PluginInfo {
418 build_number: plugin_build_number,
419 commit_id: plugin_commit.to_string(),
420 })
421}
422
/// Version information about the active `wasi-nn` ggml plugin, as reported by
/// the plugin's own metadata.
#[derive(Debug, Clone)]
pub struct PluginInfo {
    /// Build number reported by the plugin (`llama_build_number` field).
    pub build_number: u64,
    /// Commit id reported by the plugin (`llama_commit` field).
    pub commit_id: String,
}
429impl std::fmt::Display for PluginInfo {
430 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
431 write!(
432 f,
433 "wasinn-ggml plugin: b{}(commit {})",
434 self.build_number, self.commit_id
435 )
436 }
437}
438
439pub fn running_mode() -> Result<RunningMode, LlamaCoreError> {
441 #[cfg(feature = "logging")]
442 debug!(target: "stdout", "Get the running mode.");
443
444 match RUNNING_MODE.get() {
445 Some(mode) => match mode.read() {
446 Ok(mode) => Ok(*mode),
447 Err(e) => {
448 let err_msg = format!("Fail to get the underlying value of `RUNNING_MODE`. {e}");
449
450 #[cfg(feature = "logging")]
451 error!(target: "stdout", "{err_msg}");
452
453 Err(LlamaCoreError::Operation(err_msg))
454 }
455 },
456 None => {
457 let err_msg = "Fail to get the underlying value of `RUNNING_MODE`.";
458
459 #[cfg(feature = "logging")]
460 error!(target: "stdout", "{err_msg}");
461
462 Err(LlamaCoreError::Operation(err_msg.into()))
463 }
464 }
465}
466
/// Initializes the stable-diffusion context(s) from a single full model file.
///
/// Depending on `task`, this creates the global text-to-image context, the
/// global image-to-image context, or both (for `StableDiffusionTask::Full`).
///
/// # Arguments
///
/// * `model_file` - path to the full stable-diffusion model file.
/// * `lora_model_dir` - optional directory containing LoRA models.
/// * `controlnet_path` - optional path to a ControlNet model.
/// * `controlnet_on_cpu` - keep ControlNet on the CPU; only honored when
///   `controlnet_path` is a non-empty path, otherwise forced to `false`.
/// * `clip_on_cpu` - keep the CLIP encoder on the CPU.
/// * `vae_on_cpu` - keep the VAE on the CPU.
/// * `n_threads` - number of threads to use.
/// * `task` - which context(s) to create.
///
/// # Errors
///
/// Returns `LlamaCoreError::InitContext` if building a context fails or if
/// the corresponding global (`SD_TEXT_TO_IMAGE` / `SD_IMAGE_TO_IMAGE`) has
/// already been initialized.
#[allow(clippy::too_many_arguments)]
pub fn init_sd_context_with_full_model(
    model_file: impl AsRef<str>,
    lora_model_dir: Option<&str>,
    controlnet_path: Option<&str>,
    controlnet_on_cpu: bool,
    clip_on_cpu: bool,
    vae_on_cpu: bool,
    n_threads: i32,
    task: StableDiffusionTask,
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the stable diffusion context with the full model");

    // ControlNet-on-CPU only makes sense when a ControlNet model is provided.
    let control_net_on_cpu = match controlnet_path {
        Some(path) if !path.is_empty() => controlnet_on_cpu,
        _ => false,
    };

    // Create the text-to-image context if requested.
    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
        let sd = SDBuidler::new(Task::TextToImage, model_file.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        // The builder was configured for text-to-image, so any other context
        // variant here is unexpected.
        let ctx = match ctx {
            Context::TextToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the text-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);

        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
    }

    // Create the image-to-image context if requested.
    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
        let sd = SDBuidler::new(Task::ImageToImage, model_file.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        // The builder was configured for image-to-image, so any other context
        // variant here is unexpected.
        let ctx = match ctx {
            Context::ImageToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the image-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);

        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
    }

    Ok(())
}
661
/// Initializes the stable-diffusion context(s) from a standalone diffusion
/// model plus separate VAE, CLIP-L, and T5-XXL model files.
///
/// Depending on `task`, this creates the global text-to-image context, the
/// global image-to-image context, or both (for `StableDiffusionTask::Full`).
///
/// # Arguments
///
/// * `model_file` - path to the standalone diffusion model file.
/// * `vae` - path to the VAE model file.
/// * `clip_l` - path to the CLIP-L text-encoder model file.
/// * `t5xxl` - path to the T5-XXL text-encoder model file.
/// * `lora_model_dir` - optional directory containing LoRA models.
/// * `controlnet_path` - optional path to a ControlNet model.
/// * `controlnet_on_cpu` - keep ControlNet on the CPU; only honored when
///   `controlnet_path` is a non-empty path, otherwise forced to `false`.
/// * `clip_on_cpu` - keep the CLIP encoder on the CPU.
/// * `vae_on_cpu` - keep the VAE on the CPU.
/// * `n_threads` - number of threads to use.
/// * `task` - which context(s) to create.
///
/// # Errors
///
/// Returns `LlamaCoreError::InitContext` if building a context fails or if
/// the corresponding global (`SD_TEXT_TO_IMAGE` / `SD_IMAGE_TO_IMAGE`) has
/// already been initialized.
#[allow(clippy::too_many_arguments)]
pub fn init_sd_context_with_standalone_model(
    model_file: impl AsRef<str>,
    vae: impl AsRef<str>,
    clip_l: impl AsRef<str>,
    t5xxl: impl AsRef<str>,
    lora_model_dir: Option<&str>,
    controlnet_path: Option<&str>,
    controlnet_on_cpu: bool,
    clip_on_cpu: bool,
    vae_on_cpu: bool,
    n_threads: i32,
    task: StableDiffusionTask,
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the stable diffusion context with the standalone diffusion model");

    // ControlNet-on-CPU only makes sense when a ControlNet model is provided.
    let control_net_on_cpu = match controlnet_path {
        Some(path) if !path.is_empty() => controlnet_on_cpu,
        _ => false,
    };

    // Create the text-to-image context if requested.
    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
        let sd = SDBuidler::new_with_standalone_model(Task::TextToImage, model_file.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_vae_path(vae.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_clip_l_path(clip_l.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_t5xxl_path(t5xxl.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        // The builder was configured for text-to-image, so any other context
        // variant here is unexpected.
        let ctx = match ctx {
            Context::TextToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the text-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);

        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
    }

    // Create the image-to-image context if requested.
    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
        let sd = SDBuidler::new_with_standalone_model(Task::ImageToImage, model_file.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_vae_path(vae.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_clip_l_path(clip_l.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_t5xxl_path(t5xxl.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        // The builder was configured for image-to-image, so any other context
        // variant here is unexpected.
        let ctx = match ctx {
            Context::ImageToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the image-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);

        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
    }

    Ok(())
}
925
/// Which stable-diffusion context(s) the `init_sd_context_*` functions create.
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
pub enum StableDiffusionTask {
    /// Create only the text-to-image context.
    TextToImage,
    /// Create only the image-to-image context.
    ImageToImage,
    /// Create both the text-to-image and image-to-image contexts.
    Full,
}
936
/// Initializes (or re-initializes) the whisper audio-transcription context
/// from the given metadata, storing the graph in the global `AUDIO_GRAPH`.
///
/// # Errors
///
/// Returns `LlamaCoreError` if the graph cannot be built, or
/// `LlamaCoreError::InitContext` if the existing graph's lock is poisoned or
/// the global cell cannot be set.
#[cfg(feature = "whisper")]
pub fn init_whisper_context(whisper_metadata: &WhisperMetadata) -> Result<(), LlamaCoreError> {
    // Build a fresh whisper graph on the CPU from the configured model file.
    let graph = GraphBuilder::new(EngineType::Whisper)?
        .with_config(whisper_metadata.clone())?
        .use_cpu()
        .build_from_files([&whisper_metadata.model_path])?;

    if let Some(mutex_graph) = AUDIO_GRAPH.get() {
        // The context already exists: swap the new graph in place.
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Re-initialize the audio context");

        let mut locked_graph = mutex_graph.lock().map_err(|e| {
            let err_msg = format!("Failed to lock the graph. Reason: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg)
        })?;
        *locked_graph = graph;
    } else {
        // First initialization: publish the graph into the global cell.
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Initialize the audio context");

        AUDIO_GRAPH.set(Mutex::new(graph)).map_err(|_| {
            let err_msg = "Failed to initialize the audio context. Reason: The `AUDIO_GRAPH` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg.into())
        })?;
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "The audio context has been initialized");

    Ok(())
}
983
984pub fn init_piper_context(
995 piper_metadata: &PiperMetadata,
996 voice_model: impl AsRef<Path>,
997 voice_config: impl AsRef<Path>,
998 espeak_ng_data: impl AsRef<Path>,
999) -> Result<(), LlamaCoreError> {
1000 #[cfg(feature = "logging")]
1001 info!(target: "stdout", "Initializing the piper context");
1002
1003 let config = serde_json::json!({
1004 "model": voice_model.as_ref().to_owned(),
1005 "config": voice_config.as_ref().to_owned(),
1006 "espeak_data": espeak_ng_data.as_ref().to_owned(),
1007 });
1008
1009 let graph = GraphBuilder::new(EngineType::Piper)?
1011 .with_config(piper_metadata.clone())?
1012 .use_cpu()
1013 .build_from_buffer([config.to_string()])?;
1014
1015 PIPER_GRAPH.set(Mutex::new(graph)).map_err(|_| {
1016 let err_msg = "Failed to initialize the piper context. Reason: The `PIPER_GRAPH` has already been initialized";
1017
1018 #[cfg(feature = "logging")]
1019 error!(target: "stdout", "{err_msg}");
1020
1021 LlamaCoreError::InitContext(err_msg.into())
1022 })?;
1023
1024 #[cfg(feature = "logging")]
1025 info!(target: "stdout", "The piper context has been initialized");
1026
1027 Ok(())
1028}