1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
4
5#[cfg(feature = "logging")]
6#[macro_use]
7extern crate log;
8
9pub mod audio;
10pub mod chat;
11pub mod completions;
12pub mod embeddings;
13pub mod error;
14pub mod files;
15pub mod graph;
16pub mod images;
17pub mod metadata;
18pub mod models;
19#[cfg(feature = "rag")]
20#[cfg_attr(docsrs, doc(cfg(feature = "rag")))]
21pub mod rag;
22#[cfg(feature = "search")]
23#[cfg_attr(docsrs, doc(cfg(feature = "search")))]
24pub mod search;
25pub mod tts;
26pub mod utils;
27
28pub use error::LlamaCoreError;
29pub use graph::{EngineType, Graph, GraphBuilder};
30#[cfg(feature = "whisper")]
31use metadata::whisper::WhisperMetadata;
32pub use metadata::{
33 ggml::{GgmlMetadata, GgmlTtsMetadata},
34 piper::PiperMetadata,
35 BaseMetadata,
36};
37use once_cell::sync::OnceCell;
38use std::{
39 collections::HashMap,
40 path::Path,
41 sync::{Mutex, RwLock},
42};
43use utils::{get_output_buffer, RunningMode};
44use wasmedge_stable_diffusion::*;
45
46pub(crate) static CHAT_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
48 OnceCell::new();
49pub(crate) static EMBEDDING_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
51 OnceCell::new();
52pub(crate) static TTS_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlTtsMetadata>>>> =
54 OnceCell::new();
55pub(crate) static CACHED_UTF8_ENCODINGS: OnceCell<Mutex<Vec<u8>>> = OnceCell::new();
57pub(crate) static RUNNING_MODE: OnceCell<RwLock<RunningMode>> = OnceCell::new();
59pub(crate) static SD_TEXT_TO_IMAGE: OnceCell<Mutex<TextToImage>> = OnceCell::new();
61pub(crate) static SD_IMAGE_TO_IMAGE: OnceCell<Mutex<ImageToImage>> = OnceCell::new();
63#[cfg(feature = "whisper")]
65pub(crate) static AUDIO_GRAPH: OnceCell<Mutex<Graph<WhisperMetadata>>> = OnceCell::new();
66pub(crate) static PIPER_GRAPH: OnceCell<Mutex<Graph<PiperMetadata>>> = OnceCell::new();
68
69pub(crate) const MAX_BUFFER_SIZE: usize = 2usize.pow(14) * 15 + 128;
70pub(crate) const OUTPUT_TENSOR: usize = 0;
71const PLUGIN_VERSION: usize = 1;
72
73pub const ARCHIVES_DIR: &str = "archives";
75
76pub fn init_ggml_chat_context(metadata_for_chats: &[GgmlMetadata]) -> Result<(), LlamaCoreError> {
78 #[cfg(feature = "logging")]
79 info!(target: "stdout", "Initializing the core context");
80
81 if metadata_for_chats.is_empty() {
82 let err_msg = "The metadata for chat models is empty";
83
84 #[cfg(feature = "logging")]
85 error!(target: "stdout", "{err_msg}");
86
87 return Err(LlamaCoreError::InitContext(err_msg.into()));
88 }
89
90 let mut chat_graphs = HashMap::new();
91 for metadata in metadata_for_chats {
92 let graph = Graph::new(metadata.clone())?;
93
94 chat_graphs.insert(graph.name().to_string(), graph);
95 }
96 CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
97 let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";
98
99 #[cfg(feature = "logging")]
100 error!(target: "stdout", "{err_msg}");
101
102 LlamaCoreError::InitContext(err_msg.into())
103 })?;
104
105 let running_mode = RunningMode::CHAT;
107 match RUNNING_MODE.get() {
108 Some(mode) => {
109 let mut mode = mode.write().unwrap();
110 *mode |= running_mode;
111 }
112 None => {
113 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
114 let err_msg = "Failed to initialize the chat context. Reason: The `RUNNING_MODE` has already been initialized";
115
116 #[cfg(feature = "logging")]
117 error!(target: "stdout", "{err_msg}");
118
119 LlamaCoreError::InitContext(err_msg.into())
120 })?;
121 }
122 }
123
124 Ok(())
125}
126
127pub fn init_ggml_embeddings_context(
129 metadata_for_embeddings: &[GgmlMetadata],
130) -> Result<(), LlamaCoreError> {
131 #[cfg(feature = "logging")]
132 info!(target: "stdout", "Initializing the embeddings context");
133
134 if metadata_for_embeddings.is_empty() {
135 let err_msg = "The metadata for chat models is empty";
136
137 #[cfg(feature = "logging")]
138 error!(target: "stdout", "{err_msg}");
139
140 return Err(LlamaCoreError::InitContext(err_msg.into()));
141 }
142
143 let mut embedding_graphs = HashMap::new();
144 for metadata in metadata_for_embeddings {
145 let graph = Graph::new(metadata.clone())?;
146
147 embedding_graphs.insert(graph.name().to_string(), graph);
148 }
149 EMBEDDING_GRAPHS
150 .set(Mutex::new(embedding_graphs))
151 .map_err(|_| {
152 let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";
153
154 #[cfg(feature = "logging")]
155 error!(target: "stdout", "{err_msg}");
156
157 LlamaCoreError::InitContext(err_msg.into())
158 })?;
159
160 let running_mode = RunningMode::EMBEDDINGS;
162 match RUNNING_MODE.get() {
163 Some(mode) => {
164 let mut mode = mode.write().unwrap();
165 *mode |= running_mode;
166 }
167 None => {
168 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
169 let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
170
171 #[cfg(feature = "logging")]
172 error!(target: "stdout", "{err_msg}");
173
174 LlamaCoreError::InitContext(err_msg.into())
175 })?;
176 }
177 }
178
179 Ok(())
180}
181
182#[cfg(feature = "rag")]
184pub fn init_ggml_rag_context(
185 metadata_for_chats: &[GgmlMetadata],
186 metadata_for_embeddings: &[GgmlMetadata],
187) -> Result<(), LlamaCoreError> {
188 #[cfg(feature = "logging")]
189 info!(target: "stdout", "Initializing the core context for RAG scenarios");
190
191 if metadata_for_chats.is_empty() {
193 let err_msg = "The metadata for chat models is empty";
194
195 #[cfg(feature = "logging")]
196 error!(target: "stdout", "{err_msg}");
197
198 return Err(LlamaCoreError::InitContext(err_msg.into()));
199 }
200 let mut chat_graphs = HashMap::new();
201 for metadata in metadata_for_chats {
202 let graph = Graph::new(metadata.clone())?;
203
204 chat_graphs.insert(graph.name().to_string(), graph);
205 }
206 CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
207 let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";
208
209 #[cfg(feature = "logging")]
210 error!(target: "stdout", "{err_msg}");
211
212 LlamaCoreError::InitContext(err_msg.into())
213 })?;
214
215 if metadata_for_embeddings.is_empty() {
217 let err_msg = "The metadata for embeddings is empty";
218
219 #[cfg(feature = "logging")]
220 error!(target: "stdout", "{err_msg}");
221
222 return Err(LlamaCoreError::InitContext(err_msg.into()));
223 }
224 let mut embedding_graphs = HashMap::new();
225 for metadata in metadata_for_embeddings {
226 let graph = Graph::new(metadata.clone())?;
227
228 embedding_graphs.insert(graph.name().to_string(), graph);
229 }
230 EMBEDDING_GRAPHS
231 .set(Mutex::new(embedding_graphs))
232 .map_err(|_| {
233 let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";
234
235 #[cfg(feature = "logging")]
236 error!(target: "stdout", "{err_msg}");
237
238 LlamaCoreError::InitContext(err_msg.into())
239 })?;
240
241 let running_mode = RunningMode::RAG;
243 match RUNNING_MODE.get() {
244 Some(mode) => {
245 let mut mode = mode.write().unwrap();
246 *mode |= running_mode;
247 }
248 None => {
249 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
250 let err_msg = "Failed to initialize the rag context. Reason: The `RUNNING_MODE` has already been initialized";
251
252 #[cfg(feature = "logging")]
253 error!(target: "stdout", "{err_msg}");
254
255 LlamaCoreError::InitContext(err_msg.into())
256 })?;
257 }
258 }
259
260 Ok(())
261}
262
263pub fn init_ggml_tts_context(metadata_for_tts: &[GgmlTtsMetadata]) -> Result<(), LlamaCoreError> {
265 #[cfg(feature = "logging")]
266 info!(target: "stdout", "Initializing the TTS context");
267
268 if metadata_for_tts.is_empty() {
269 let err_msg = "The metadata for tts models is empty";
270
271 #[cfg(feature = "logging")]
272 error!(target: "stdout", "{err_msg}");
273
274 return Err(LlamaCoreError::InitContext(err_msg.into()));
275 }
276
277 let mut tts_graphs = HashMap::new();
278 for metadata in metadata_for_tts {
279 let graph = Graph::new(metadata.clone())?;
280
281 tts_graphs.insert(graph.name().to_string(), graph);
282 }
283 TTS_GRAPHS.set(Mutex::new(tts_graphs)).map_err(|_| {
284 let err_msg = "Failed to initialize the core context. Reason: The `TTS_GRAPHS` has already been initialized";
285
286 #[cfg(feature = "logging")]
287 error!(target: "stdout", "{err_msg}");
288
289 LlamaCoreError::InitContext(err_msg.into())
290 })?;
291
292 let running_mode = RunningMode::TTS;
294 match RUNNING_MODE.get() {
295 Some(mode) => {
296 let mut mode = mode.write().unwrap();
297 *mode |= running_mode;
298 }
299 None => {
300 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
301 let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
302
303 #[cfg(feature = "logging")]
304 error!(target: "stdout", "{err_msg}");
305
306 LlamaCoreError::InitContext(err_msg.into())
307 })?;
308 }
309 }
310
311 Ok(())
312}
313
314pub fn get_plugin_info() -> Result<PluginInfo, LlamaCoreError> {
318 #[cfg(feature = "logging")]
319 debug!(target: "stdout", "Getting the plugin info");
320
321 let running_mode = running_mode()?;
322
323 if running_mode.contains(RunningMode::CHAT) || running_mode.contains(RunningMode::RAG) {
324 let chat_graphs = match CHAT_GRAPHS.get() {
325 Some(chat_graphs) => chat_graphs,
326 None => {
327 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
328
329 #[cfg(feature = "logging")]
330 error!(target: "stdout", "{err_msg}");
331
332 return Err(LlamaCoreError::Operation(err_msg.into()));
333 }
334 };
335
336 let chat_graphs = chat_graphs.lock().map_err(|e| {
337 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
338
339 #[cfg(feature = "logging")]
340 error!(target: "stdout", "{}", &err_msg);
341
342 LlamaCoreError::Operation(err_msg)
343 })?;
344
345 let graph = match chat_graphs.values().next() {
346 Some(graph) => graph,
347 None => {
348 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
349
350 #[cfg(feature = "logging")]
351 error!(target: "stdout", "{err_msg}");
352
353 return Err(LlamaCoreError::Operation(err_msg.into()));
354 }
355 };
356
357 get_plugin_info_by_graph(graph)
358 } else if running_mode.contains(RunningMode::EMBEDDINGS) {
359 let embedding_graphs = match EMBEDDING_GRAPHS.get() {
360 Some(embedding_graphs) => embedding_graphs,
361 None => {
362 let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
363
364 #[cfg(feature = "logging")]
365 error!(target: "stdout", "{err_msg}");
366
367 return Err(LlamaCoreError::Operation(err_msg.into()));
368 }
369 };
370
371 let embedding_graphs = embedding_graphs.lock().map_err(|e| {
372 let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {e}");
373
374 #[cfg(feature = "logging")]
375 error!(target: "stdout", "{}", &err_msg);
376
377 LlamaCoreError::Operation(err_msg)
378 })?;
379
380 let graph = match embedding_graphs.values().next() {
381 Some(graph) => graph,
382 None => {
383 let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
384
385 #[cfg(feature = "logging")]
386 error!(target: "stdout", "{err_msg}");
387
388 return Err(LlamaCoreError::Operation(err_msg.into()));
389 }
390 };
391
392 get_plugin_info_by_graph(graph)
393 } else if running_mode.contains(RunningMode::TTS) {
394 let tts_graphs = match TTS_GRAPHS.get() {
395 Some(tts_graphs) => tts_graphs,
396 None => {
397 let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
398
399 #[cfg(feature = "logging")]
400 error!(target: "stdout", "{err_msg}");
401
402 return Err(LlamaCoreError::Operation(err_msg.into()));
403 }
404 };
405
406 let tts_graphs = tts_graphs.lock().map_err(|e| {
407 let err_msg = format!("Fail to acquire the lock of `TTS_GRAPHS`. {e}");
408
409 #[cfg(feature = "logging")]
410 error!(target: "stdout", "{}", &err_msg);
411
412 LlamaCoreError::Operation(err_msg)
413 })?;
414
415 let graph = match tts_graphs.values().next() {
416 Some(graph) => graph,
417 None => {
418 let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
419
420 #[cfg(feature = "logging")]
421 error!(target: "stdout", "{err_msg}");
422
423 return Err(LlamaCoreError::Operation(err_msg.into()));
424 }
425 };
426
427 get_plugin_info_by_graph(graph)
428 } else {
429 let err_msg = "RUNNING_MODE is not set";
430
431 #[cfg(feature = "logging")]
432 error!(target: "stdout", "{err_msg}");
433
434 Err(LlamaCoreError::Operation(err_msg.into()))
435 }
436}
437
438fn get_plugin_info_by_graph<M: BaseMetadata + serde::Serialize + Clone + Default>(
439 graph: &Graph<M>,
440) -> Result<PluginInfo, LlamaCoreError> {
441 #[cfg(feature = "logging")]
442 debug!(target: "stdout", "Getting the plugin info by the graph named {}", graph.name());
443
444 let output_buffer = get_output_buffer(graph, PLUGIN_VERSION)?;
446 let metadata: serde_json::Value = serde_json::from_slice(&output_buffer[..]).map_err(|e| {
447 let err_msg = format!("Fail to deserialize the plugin metadata. {e}");
448
449 #[cfg(feature = "logging")]
450 error!(target: "stdout", "{}", &err_msg);
451
452 LlamaCoreError::Operation(err_msg)
453 })?;
454
455 let plugin_build_number = match metadata.get("llama_build_number") {
457 Some(value) => match value.as_u64() {
458 Some(number) => number,
459 None => {
460 let err_msg = "Failed to convert the build number of the plugin to u64";
461
462 #[cfg(feature = "logging")]
463 error!(target: "stdout", "{err_msg}");
464
465 return Err(LlamaCoreError::Operation(err_msg.into()));
466 }
467 },
468 None => {
469 let err_msg = "Metadata does not have the field `llama_build_number`.";
470
471 #[cfg(feature = "logging")]
472 error!(target: "stdout", "{err_msg}");
473
474 return Err(LlamaCoreError::Operation(err_msg.into()));
475 }
476 };
477
478 let plugin_commit = match metadata.get("llama_commit") {
480 Some(value) => match value.as_str() {
481 Some(commit) => commit,
482 None => {
483 let err_msg = "Failed to convert the commit id of the plugin to string";
484
485 #[cfg(feature = "logging")]
486 error!(target: "stdout", "{err_msg}");
487
488 return Err(LlamaCoreError::Operation(err_msg.into()));
489 }
490 },
491 None => {
492 let err_msg = "Metadata does not have the field `llama_commit`.";
493
494 #[cfg(feature = "logging")]
495 error!(target: "stdout", "{err_msg}");
496
497 return Err(LlamaCoreError::Operation(err_msg.into()));
498 }
499 };
500
501 #[cfg(feature = "logging")]
502 debug!(target: "stdout", "Plugin info: b{plugin_build_number}(commit {plugin_commit})");
503
504 Ok(PluginInfo {
505 build_number: plugin_build_number,
506 commit_id: plugin_commit.to_string(),
507 })
508}
509
510#[derive(Debug, Clone)]
512pub struct PluginInfo {
513 pub build_number: u64,
514 pub commit_id: String,
515}
516impl std::fmt::Display for PluginInfo {
517 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
518 write!(
519 f,
520 "wasinn-ggml plugin: b{}(commit {})",
521 self.build_number, self.commit_id
522 )
523 }
524}
525
526pub fn running_mode() -> Result<RunningMode, LlamaCoreError> {
528 #[cfg(feature = "logging")]
529 debug!(target: "stdout", "Get the running mode.");
530
531 match RUNNING_MODE.get() {
532 Some(mode) => match mode.read() {
533 Ok(mode) => Ok(*mode),
534 Err(e) => {
535 let err_msg = format!("Fail to get the underlying value of `RUNNING_MODE`. {e}");
536
537 #[cfg(feature = "logging")]
538 error!(target: "stdout", "{err_msg}");
539
540 Err(LlamaCoreError::Operation(err_msg))
541 }
542 },
543 None => {
544 let err_msg = "Fail to get the underlying value of `RUNNING_MODE`.";
545
546 #[cfg(feature = "logging")]
547 error!(target: "stdout", "{err_msg}");
548
549 Err(LlamaCoreError::Operation(err_msg.into()))
550 }
551 }
552}
553
554#[allow(clippy::too_many_arguments)]
574pub fn init_sd_context_with_full_model(
575 model_file: impl AsRef<str>,
576 lora_model_dir: Option<&str>,
577 controlnet_path: Option<&str>,
578 controlnet_on_cpu: bool,
579 clip_on_cpu: bool,
580 vae_on_cpu: bool,
581 n_threads: i32,
582 task: StableDiffusionTask,
583) -> Result<(), LlamaCoreError> {
584 #[cfg(feature = "logging")]
585 info!(target: "stdout", "Initializing the stable diffusion context with the full model");
586
587 let control_net_on_cpu = match controlnet_path {
588 Some(path) if !path.is_empty() => controlnet_on_cpu,
589 _ => false,
590 };
591
592 if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
594 let sd = SDBuidler::new(Task::TextToImage, model_file.as_ref())
595 .map_err(|e| {
596 let err_msg =
597 format!("Failed to initialize the stable diffusion context. Reason: {e}");
598
599 #[cfg(feature = "logging")]
600 error!(target: "stdout", "{err_msg}");
601
602 LlamaCoreError::InitContext(err_msg)
603 })?
604 .with_lora_model_dir(lora_model_dir.unwrap_or_default())
605 .map_err(|e| {
606 let err_msg =
607 format!("Failed to initialize the stable diffusion context. Reason: {e}");
608
609 #[cfg(feature = "logging")]
610 error!(target: "stdout", "{err_msg}");
611
612 LlamaCoreError::InitContext(err_msg)
613 })?
614 .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
615 .map_err(|e| {
616 let err_msg =
617 format!("Failed to initialize the stable diffusion context. Reason: {e}");
618
619 #[cfg(feature = "logging")]
620 error!(target: "stdout", "{err_msg}");
621
622 LlamaCoreError::InitContext(err_msg)
623 })?
624 .clip_on_cpu(clip_on_cpu)
625 .vae_on_cpu(vae_on_cpu)
626 .with_n_threads(n_threads)
627 .build();
628
629 #[cfg(feature = "logging")]
630 info!(target: "stdout", "sd: {:?}", &sd);
631
632 let ctx = sd.create_context().map_err(|e| {
633 let err_msg = format!("Fail to create the context. {e}");
634
635 #[cfg(feature = "logging")]
636 error!(target: "stdout", "{}", &err_msg);
637
638 LlamaCoreError::InitContext(err_msg)
639 })?;
640
641 let ctx = match ctx {
642 Context::TextToImage(ctx) => ctx,
643 _ => {
644 let err_msg = "Fail to get the context for the text-to-image task";
645
646 #[cfg(feature = "logging")]
647 error!(target: "stdout", "{err_msg}");
648
649 return Err(LlamaCoreError::InitContext(err_msg.into()));
650 }
651 };
652
653 #[cfg(feature = "logging")]
654 info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
655
656 SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
657 let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";
658
659 #[cfg(feature = "logging")]
660 error!(target: "stdout", "{err_msg}");
661
662 LlamaCoreError::InitContext(err_msg.into())
663 })?;
664
665 #[cfg(feature = "logging")]
666 info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
667 }
668
669 if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
671 let sd = SDBuidler::new(Task::ImageToImage, model_file.as_ref())
672 .map_err(|e| {
673 let err_msg =
674 format!("Failed to initialize the stable diffusion context. Reason: {e}");
675
676 #[cfg(feature = "logging")]
677 error!(target: "stdout", "{err_msg}");
678
679 LlamaCoreError::InitContext(err_msg)
680 })?
681 .with_lora_model_dir(lora_model_dir.unwrap_or_default())
682 .map_err(|e| {
683 let err_msg =
684 format!("Failed to initialize the stable diffusion context. Reason: {e}");
685
686 #[cfg(feature = "logging")]
687 error!(target: "stdout", "{err_msg}");
688
689 LlamaCoreError::InitContext(err_msg)
690 })?
691 .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
692 .map_err(|e| {
693 let err_msg =
694 format!("Failed to initialize the stable diffusion context. Reason: {e}");
695
696 #[cfg(feature = "logging")]
697 error!(target: "stdout", "{err_msg}");
698
699 LlamaCoreError::InitContext(err_msg)
700 })?
701 .clip_on_cpu(clip_on_cpu)
702 .vae_on_cpu(vae_on_cpu)
703 .with_n_threads(n_threads)
704 .build();
705
706 #[cfg(feature = "logging")]
707 info!(target: "stdout", "sd: {:?}", &sd);
708
709 let ctx = sd.create_context().map_err(|e| {
710 let err_msg = format!("Fail to create the context. {e}");
711
712 #[cfg(feature = "logging")]
713 error!(target: "stdout", "{}", &err_msg);
714
715 LlamaCoreError::InitContext(err_msg)
716 })?;
717
718 let ctx = match ctx {
719 Context::ImageToImage(ctx) => ctx,
720 _ => {
721 let err_msg = "Fail to get the context for the image-to-image task";
722
723 #[cfg(feature = "logging")]
724 error!(target: "stdout", "{err_msg}");
725
726 return Err(LlamaCoreError::InitContext(err_msg.into()));
727 }
728 };
729
730 #[cfg(feature = "logging")]
731 info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
732
733 SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
734 let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";
735
736 #[cfg(feature = "logging")]
737 error!(target: "stdout", "{err_msg}");
738
739 LlamaCoreError::InitContext(err_msg.into())
740 })?;
741
742 #[cfg(feature = "logging")]
743 info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
744 }
745
746 Ok(())
747}
748
749#[allow(clippy::too_many_arguments)]
775pub fn init_sd_context_with_standalone_model(
776 model_file: impl AsRef<str>,
777 vae: impl AsRef<str>,
778 clip_l: impl AsRef<str>,
779 t5xxl: impl AsRef<str>,
780 lora_model_dir: Option<&str>,
781 controlnet_path: Option<&str>,
782 controlnet_on_cpu: bool,
783 clip_on_cpu: bool,
784 vae_on_cpu: bool,
785 n_threads: i32,
786 task: StableDiffusionTask,
787) -> Result<(), LlamaCoreError> {
788 #[cfg(feature = "logging")]
789 info!(target: "stdout", "Initializing the stable diffusion context with the standalone diffusion model");
790
791 let control_net_on_cpu = match controlnet_path {
792 Some(path) if !path.is_empty() => controlnet_on_cpu,
793 _ => false,
794 };
795
796 if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
798 let sd = SDBuidler::new_with_standalone_model(Task::TextToImage, model_file.as_ref())
799 .map_err(|e| {
800 let err_msg =
801 format!("Failed to initialize the stable diffusion context. Reason: {e}");
802
803 #[cfg(feature = "logging")]
804 error!(target: "stdout", "{err_msg}");
805
806 LlamaCoreError::InitContext(err_msg)
807 })?
808 .with_vae_path(vae.as_ref())
809 .map_err(|e| {
810 let err_msg =
811 format!("Failed to initialize the stable diffusion context. Reason: {e}");
812
813 #[cfg(feature = "logging")]
814 error!(target: "stdout", "{err_msg}");
815
816 LlamaCoreError::InitContext(err_msg)
817 })?
818 .with_clip_l_path(clip_l.as_ref())
819 .map_err(|e| {
820 let err_msg =
821 format!("Failed to initialize the stable diffusion context. Reason: {e}");
822
823 #[cfg(feature = "logging")]
824 error!(target: "stdout", "{err_msg}");
825
826 LlamaCoreError::InitContext(err_msg)
827 })?
828 .with_t5xxl_path(t5xxl.as_ref())
829 .map_err(|e| {
830 let err_msg =
831 format!("Failed to initialize the stable diffusion context. Reason: {e}");
832
833 #[cfg(feature = "logging")]
834 error!(target: "stdout", "{err_msg}");
835
836 LlamaCoreError::InitContext(err_msg)
837 })?
838 .with_lora_model_dir(lora_model_dir.unwrap_or_default())
839 .map_err(|e| {
840 let err_msg =
841 format!("Failed to initialize the stable diffusion context. Reason: {e}");
842
843 #[cfg(feature = "logging")]
844 error!(target: "stdout", "{err_msg}");
845
846 LlamaCoreError::InitContext(err_msg)
847 })?
848 .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
849 .map_err(|e| {
850 let err_msg =
851 format!("Failed to initialize the stable diffusion context. Reason: {e}");
852
853 #[cfg(feature = "logging")]
854 error!(target: "stdout", "{err_msg}");
855
856 LlamaCoreError::InitContext(err_msg)
857 })?
858 .clip_on_cpu(clip_on_cpu)
859 .vae_on_cpu(vae_on_cpu)
860 .with_n_threads(n_threads)
861 .build();
862
863 #[cfg(feature = "logging")]
864 info!(target: "stdout", "sd: {:?}", &sd);
865
866 let ctx = sd.create_context().map_err(|e| {
867 let err_msg = format!("Fail to create the context. {e}");
868
869 #[cfg(feature = "logging")]
870 error!(target: "stdout", "{}", &err_msg);
871
872 LlamaCoreError::InitContext(err_msg)
873 })?;
874
875 let ctx = match ctx {
876 Context::TextToImage(ctx) => ctx,
877 _ => {
878 let err_msg = "Fail to get the context for the text-to-image task";
879
880 #[cfg(feature = "logging")]
881 error!(target: "stdout", "{err_msg}");
882
883 return Err(LlamaCoreError::InitContext(err_msg.into()));
884 }
885 };
886
887 #[cfg(feature = "logging")]
888 info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
889
890 SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
891 let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";
892
893 #[cfg(feature = "logging")]
894 error!(target: "stdout", "{err_msg}");
895
896 LlamaCoreError::InitContext(err_msg.into())
897 })?;
898
899 #[cfg(feature = "logging")]
900 info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
901 }
902
903 if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
905 let sd = SDBuidler::new_with_standalone_model(Task::ImageToImage, model_file.as_ref())
906 .map_err(|e| {
907 let err_msg =
908 format!("Failed to initialize the stable diffusion context. Reason: {e}");
909
910 #[cfg(feature = "logging")]
911 error!(target: "stdout", "{err_msg}");
912
913 LlamaCoreError::InitContext(err_msg)
914 })?
915 .with_vae_path(vae.as_ref())
916 .map_err(|e| {
917 let err_msg =
918 format!("Failed to initialize the stable diffusion context. Reason: {e}");
919
920 #[cfg(feature = "logging")]
921 error!(target: "stdout", "{err_msg}");
922
923 LlamaCoreError::InitContext(err_msg)
924 })?
925 .with_clip_l_path(clip_l.as_ref())
926 .map_err(|e| {
927 let err_msg =
928 format!("Failed to initialize the stable diffusion context. Reason: {e}");
929
930 #[cfg(feature = "logging")]
931 error!(target: "stdout", "{err_msg}");
932
933 LlamaCoreError::InitContext(err_msg)
934 })?
935 .with_t5xxl_path(t5xxl.as_ref())
936 .map_err(|e| {
937 let err_msg =
938 format!("Failed to initialize the stable diffusion context. Reason: {e}");
939
940 #[cfg(feature = "logging")]
941 error!(target: "stdout", "{err_msg}");
942
943 LlamaCoreError::InitContext(err_msg)
944 })?
945 .with_lora_model_dir(lora_model_dir.unwrap_or_default())
946 .map_err(|e| {
947 let err_msg =
948 format!("Failed to initialize the stable diffusion context. Reason: {e}");
949
950 #[cfg(feature = "logging")]
951 error!(target: "stdout", "{err_msg}");
952
953 LlamaCoreError::InitContext(err_msg)
954 })?
955 .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
956 .map_err(|e| {
957 let err_msg =
958 format!("Failed to initialize the stable diffusion context. Reason: {e}");
959
960 #[cfg(feature = "logging")]
961 error!(target: "stdout", "{err_msg}");
962
963 LlamaCoreError::InitContext(err_msg)
964 })?
965 .clip_on_cpu(clip_on_cpu)
966 .vae_on_cpu(vae_on_cpu)
967 .with_n_threads(n_threads)
968 .build();
969
970 #[cfg(feature = "logging")]
971 info!(target: "stdout", "sd: {:?}", &sd);
972
973 let ctx = sd.create_context().map_err(|e| {
974 let err_msg = format!("Fail to create the context. {e}");
975
976 #[cfg(feature = "logging")]
977 error!(target: "stdout", "{}", &err_msg);
978
979 LlamaCoreError::InitContext(err_msg)
980 })?;
981
982 let ctx = match ctx {
983 Context::ImageToImage(ctx) => ctx,
984 _ => {
985 let err_msg = "Fail to get the context for the image-to-image task";
986
987 #[cfg(feature = "logging")]
988 error!(target: "stdout", "{err_msg}");
989
990 return Err(LlamaCoreError::InitContext(err_msg.into()));
991 }
992 };
993
994 #[cfg(feature = "logging")]
995 info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
996
997 SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
998 let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";
999
1000 #[cfg(feature = "logging")]
1001 error!(target: "stdout", "{err_msg}");
1002
1003 LlamaCoreError::InitContext(err_msg.into())
1004 })?;
1005
1006 #[cfg(feature = "logging")]
1007 info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
1008 }
1009
1010 Ok(())
1011}
1012
1013#[derive(Clone, Debug, Copy, PartialEq, Eq)]
1015pub enum StableDiffusionTask {
1016 TextToImage,
1018 ImageToImage,
1020 Full,
1022}
1023
1024#[cfg(feature = "whisper")]
1026pub fn init_whisper_context(whisper_metadata: &WhisperMetadata) -> Result<(), LlamaCoreError> {
1027 let graph = GraphBuilder::new(EngineType::Whisper)?
1029 .with_config(whisper_metadata.clone())?
1030 .use_cpu()
1031 .build_from_files([&whisper_metadata.model_path])?;
1032
1033 match AUDIO_GRAPH.get() {
1034 Some(mutex_graph) => {
1035 #[cfg(feature = "logging")]
1036 info!(target: "stdout", "Re-initialize the audio context");
1037
1038 match mutex_graph.lock() {
1039 Ok(mut locked_graph) => *locked_graph = graph,
1040 Err(e) => {
1041 let err_msg = format!("Failed to lock the graph. Reason: {e}");
1042
1043 #[cfg(feature = "logging")]
1044 error!(target: "stdout", "{err_msg}");
1045
1046 return Err(LlamaCoreError::InitContext(err_msg));
1047 }
1048 }
1049 }
1050 None => {
1051 #[cfg(feature = "logging")]
1052 info!(target: "stdout", "Initialize the audio context");
1053
1054 AUDIO_GRAPH.set(Mutex::new(graph)).map_err(|_| {
1055 let err_msg = "Failed to initialize the audio context. Reason: The `AUDIO_GRAPH` has already been initialized";
1056
1057 #[cfg(feature = "logging")]
1058 error!(target: "stdout", "{err_msg}");
1059
1060 LlamaCoreError::InitContext(err_msg.into())
1061 })?;
1062 }
1063 }
1064
1065 #[cfg(feature = "logging")]
1066 info!(target: "stdout", "The audio context has been initialized");
1067
1068 Ok(())
1069}
1070
1071pub fn init_piper_context(
1082 piper_metadata: &PiperMetadata,
1083 voice_model: impl AsRef<Path>,
1084 voice_config: impl AsRef<Path>,
1085 espeak_ng_data: impl AsRef<Path>,
1086) -> Result<(), LlamaCoreError> {
1087 #[cfg(feature = "logging")]
1088 info!(target: "stdout", "Initializing the piper context");
1089
1090 let config = serde_json::json!({
1091 "model": voice_model.as_ref().to_owned(),
1092 "config": voice_config.as_ref().to_owned(),
1093 "espeak_data": espeak_ng_data.as_ref().to_owned(),
1094 });
1095
1096 let graph = GraphBuilder::new(EngineType::Piper)?
1098 .with_config(piper_metadata.clone())?
1099 .use_cpu()
1100 .build_from_buffer([config.to_string()])?;
1101
1102 PIPER_GRAPH.set(Mutex::new(graph)).map_err(|_| {
1103 let err_msg = "Failed to initialize the piper context. Reason: The `PIPER_GRAPH` has already been initialized";
1104
1105 #[cfg(feature = "logging")]
1106 error!(target: "stdout", "{err_msg}");
1107
1108 LlamaCoreError::InitContext(err_msg.into())
1109 })?;
1110
1111 #[cfg(feature = "logging")]
1112 info!(target: "stdout", "The piper context has been initialized");
1113
1114 Ok(())
1115}