1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
4
5#[cfg(feature = "logging")]
6#[macro_use]
7extern crate log;
8
9pub mod audio;
10pub mod chat;
11pub mod completions;
12pub mod embeddings;
13pub mod error;
14pub mod files;
15pub mod graph;
16pub mod images;
17pub mod metadata;
18pub mod models;
19#[cfg(feature = "rag")]
20#[cfg_attr(docsrs, doc(cfg(feature = "rag")))]
21pub mod rag;
22#[cfg(feature = "search")]
23#[cfg_attr(docsrs, doc(cfg(feature = "search")))]
24pub mod search;
25pub mod tts;
26pub mod utils;
27
28pub use error::LlamaCoreError;
29pub use graph::{EngineType, Graph, GraphBuilder};
30#[cfg(feature = "whisper")]
31use metadata::whisper::WhisperMetadata;
32pub use metadata::{
33 ggml::{GgmlMetadata, GgmlTtsMetadata},
34 piper::PiperMetadata,
35 BaseMetadata,
36};
37use once_cell::sync::OnceCell;
38use std::{
39 collections::HashMap,
40 path::Path,
41 sync::{Mutex, RwLock},
42};
43use utils::{get_output_buffer, RunningMode};
44use wasmedge_stable_diffusion::*;
45
// Chat-completion graphs, keyed by model name; set once by
// `init_ggml_chat_context` / `init_ggml_rag_context`.
pub(crate) static CHAT_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
// Embedding graphs, keyed by model name; set once by
// `init_ggml_embeddings_context` / `init_ggml_rag_context`.
pub(crate) static EMBEDDING_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
// TTS graphs, keyed by model name; set once by `init_ggml_tts_context`.
pub(crate) static TTS_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlTtsMetadata>>>> =
    OnceCell::new();
// NOTE(review): presumably caches partial UTF-8 byte sequences between
// streamed decode steps — confirm against the usage in the chat/completions
// modules (not visible in this file).
pub(crate) static CACHED_UTF8_ENCODINGS: OnceCell<Mutex<Vec<u8>>> = OnceCell::new();
// Bitflag set of the modes the core context was initialized with; the init
// functions OR their flag into this value.
pub(crate) static RUNNING_MODE: OnceCell<RwLock<RunningMode>> = OnceCell::new();
// Stable-diffusion text-to-image context (set by the `init_sd_context_*` fns).
pub(crate) static SD_TEXT_TO_IMAGE: OnceCell<Mutex<TextToImage>> = OnceCell::new();
// Stable-diffusion image-to-image context (set by the `init_sd_context_*` fns).
pub(crate) static SD_IMAGE_TO_IMAGE: OnceCell<Mutex<ImageToImage>> = OnceCell::new();
// Audio transcription (whisper) graph; unlike the others it can be replaced
// after first initialization (see `init_whisper_context`).
#[cfg(feature = "whisper")]
pub(crate) static AUDIO_GRAPH: OnceCell<Mutex<Graph<WhisperMetadata>>> = OnceCell::new();
// Piper TTS graph; set once by `init_piper_context`.
pub(crate) static PIPER_GRAPH: OnceCell<Mutex<Graph<PiperMetadata>>> = OnceCell::new();

// NOTE(review): presumably an upper bound for output buffers read back from
// the wasi-nn graph — confirm against `utils::get_output_buffer`.
pub(crate) const MAX_BUFFER_SIZE: usize = 2usize.pow(14) * 15 + 128;
// NOTE(review): presumably the output index for inference results — its use
// is not visible in this file.
pub(crate) const OUTPUT_TENSOR: usize = 0;
// Output index queried for the plugin's build metadata (see
// `get_plugin_info_by_graph`).
const PLUGIN_VERSION: usize = 1;

// Name of the directory for archive files exposed by this crate.
pub const ARCHIVES_DIR: &str = "archives";
75
76pub fn init_ggml_chat_context(metadata_for_chats: &[GgmlMetadata]) -> Result<(), LlamaCoreError> {
78 #[cfg(feature = "logging")]
79 info!(target: "stdout", "Initializing the core context");
80
81 if metadata_for_chats.is_empty() {
82 let err_msg = "The metadata for chat models is empty";
83
84 #[cfg(feature = "logging")]
85 error!(target: "stdout", "{}", err_msg);
86
87 return Err(LlamaCoreError::InitContext(err_msg.into()));
88 }
89
90 let mut chat_graphs = HashMap::new();
91 for metadata in metadata_for_chats {
92 let graph = Graph::new(metadata.clone())?;
93
94 chat_graphs.insert(graph.name().to_string(), graph);
95 }
96 CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
97 let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";
98
99 #[cfg(feature = "logging")]
100 error!(target: "stdout", "{}", err_msg);
101
102 LlamaCoreError::InitContext(err_msg.into())
103 })?;
104
105 let running_mode = RunningMode::CHAT;
107 match RUNNING_MODE.get() {
108 Some(mode) => {
109 let mut mode = mode.write().unwrap();
110 *mode |= running_mode;
111 }
112 None => {
113 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
114 let err_msg = "Failed to initialize the chat context. Reason: The `RUNNING_MODE` has already been initialized";
115
116 #[cfg(feature = "logging")]
117 error!(target: "stdout", "{}", err_msg);
118
119 LlamaCoreError::InitContext(err_msg.into())
120 })?;
121 }
122 }
123
124 Ok(())
125}
126
127pub fn init_ggml_embeddings_context(
129 metadata_for_embeddings: &[GgmlMetadata],
130) -> Result<(), LlamaCoreError> {
131 #[cfg(feature = "logging")]
132 info!(target: "stdout", "Initializing the embeddings context");
133
134 if metadata_for_embeddings.is_empty() {
135 let err_msg = "The metadata for chat models is empty";
136
137 #[cfg(feature = "logging")]
138 error!(target: "stdout", "{}", err_msg);
139
140 return Err(LlamaCoreError::InitContext(err_msg.into()));
141 }
142
143 let mut embedding_graphs = HashMap::new();
144 for metadata in metadata_for_embeddings {
145 let graph = Graph::new(metadata.clone())?;
146
147 embedding_graphs.insert(graph.name().to_string(), graph);
148 }
149 EMBEDDING_GRAPHS
150 .set(Mutex::new(embedding_graphs))
151 .map_err(|_| {
152 let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";
153
154 #[cfg(feature = "logging")]
155 error!(target: "stdout", "{}", err_msg);
156
157 LlamaCoreError::InitContext(err_msg.into())
158 })?;
159
160 let running_mode = RunningMode::EMBEDDINGS;
162 match RUNNING_MODE.get() {
163 Some(mode) => {
164 let mut mode = mode.write().unwrap();
165 *mode |= running_mode;
166 }
167 None => {
168 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
169 let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
170
171 #[cfg(feature = "logging")]
172 error!(target: "stdout", "{}", err_msg);
173
174 LlamaCoreError::InitContext(err_msg.into())
175 })?;
176 }
177 }
178
179 Ok(())
180}
181
/// Initialize the core context for RAG scenarios (chat + embeddings).
///
/// Populates both `CHAT_GRAPHS` and `EMBEDDING_GRAPHS`, then ORs the `RAG`
/// flag into the global running mode.
///
/// # Errors
///
/// Returns `LlamaCoreError::InitContext` when either metadata slice is empty,
/// when building a graph fails, or when a global map was already set.
#[cfg(feature = "rag")]
pub fn init_ggml_rag_context(
    metadata_for_chats: &[GgmlMetadata],
    metadata_for_embeddings: &[GgmlMetadata],
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the core context for RAG scenarios");

    // --- chat models ---
    if metadata_for_chats.is_empty() {
        let err_msg = "The metadata for chat models is empty";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }
    let chat_graphs = metadata_for_chats
        .iter()
        .map(|metadata| {
            let graph = Graph::new(metadata.clone())?;
            Ok((graph.name().to_string(), graph))
        })
        .collect::<Result<HashMap<String, Graph<GgmlMetadata>>, LlamaCoreError>>()?;
    CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
        let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        LlamaCoreError::InitContext(err_msg.into())
    })?;

    // --- embedding models ---
    if metadata_for_embeddings.is_empty() {
        let err_msg = "The metadata for embeddings is empty";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }
    let embedding_graphs = metadata_for_embeddings
        .iter()
        .map(|metadata| {
            let graph = Graph::new(metadata.clone())?;
            Ok((graph.name().to_string(), graph))
        })
        .collect::<Result<HashMap<String, Graph<GgmlMetadata>>, LlamaCoreError>>()?;
    EMBEDDING_GRAPHS
        .set(Mutex::new(embedding_graphs))
        .map_err(|_| {
            let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;

    // merge the RAG flag into the running mode, creating the lock on first use
    match RUNNING_MODE.get() {
        Some(mode) => *mode.write().unwrap() |= RunningMode::RAG,
        None => {
            RUNNING_MODE
                .set(RwLock::new(RunningMode::RAG))
                .map_err(|_| {
                    let err_msg = "Failed to initialize the rag context. Reason: The `RUNNING_MODE` has already been initialized";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", err_msg);

                    LlamaCoreError::InitContext(err_msg.into())
                })?;
        }
    }

    Ok(())
}
262
263pub fn init_ggml_tts_context(metadata_for_tts: &[GgmlTtsMetadata]) -> Result<(), LlamaCoreError> {
265 #[cfg(feature = "logging")]
266 info!(target: "stdout", "Initializing the TTS context");
267
268 if metadata_for_tts.is_empty() {
269 let err_msg = "The metadata for tts models is empty";
270
271 #[cfg(feature = "logging")]
272 error!(target: "stdout", "{}", err_msg);
273
274 return Err(LlamaCoreError::InitContext(err_msg.into()));
275 }
276
277 let mut tts_graphs = HashMap::new();
278 for metadata in metadata_for_tts {
279 let graph = Graph::new(metadata.clone())?;
280
281 tts_graphs.insert(graph.name().to_string(), graph);
282 }
283 TTS_GRAPHS.set(Mutex::new(tts_graphs)).map_err(|_| {
284 let err_msg = "Failed to initialize the core context. Reason: The `TTS_GRAPHS` has already been initialized";
285
286 #[cfg(feature = "logging")]
287 error!(target: "stdout", "{}", err_msg);
288
289 LlamaCoreError::InitContext(err_msg.into())
290 })?;
291
292 let running_mode = RunningMode::TTS;
294 match RUNNING_MODE.get() {
295 Some(mode) => {
296 let mut mode = mode.write().unwrap();
297 *mode |= running_mode;
298 }
299 None => {
300 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
301 let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
302
303 #[cfg(feature = "logging")]
304 error!(target: "stdout", "{}", err_msg);
305
306 LlamaCoreError::InitContext(err_msg.into())
307 })?;
308 }
309 }
310
311 Ok(())
312}
313
/// Get the build/version information of the underlying wasi-nn ggml plugin.
///
/// Picks an arbitrary graph from whichever global map matches the current
/// running mode (chat/rag first, then embeddings, then tts) and queries it.
///
/// # Errors
///
/// Returns `LlamaCoreError::Operation` when the running mode is unset, the
/// matching graph map is uninitialized or empty, a lock is poisoned, or the
/// plugin metadata cannot be read.
pub fn get_plugin_info() -> Result<PluginInfo, LlamaCoreError> {
    #[cfg(feature = "logging")]
    debug!(target: "stdout", "Getting the plugin info");

    let running_mode = running_mode()?;

    if running_mode.contains(RunningMode::CHAT) || running_mode.contains(RunningMode::RAG) {
        // chat/rag: query any chat graph
        let chat_graphs = match CHAT_GRAPHS.get() {
            Some(chat_graphs) => chat_graphs,
            None => {
                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::Operation(err_msg.into()));
            }
        };

        let chat_graphs = chat_graphs.lock().map_err(|e| {
            let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        // any graph will do: all share the same plugin
        let graph = match chat_graphs.values().next() {
            Some(graph) => graph,
            None => {
                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::Operation(err_msg.into()));
            }
        };

        get_plugin_info_by_graph(graph)
    } else if running_mode.contains(RunningMode::EMBEDDINGS) {
        // embeddings: query any embedding graph
        let embedding_graphs = match EMBEDDING_GRAPHS.get() {
            Some(embedding_graphs) => embedding_graphs,
            None => {
                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::Operation(err_msg.into()));
            }
        };

        let embedding_graphs = embedding_graphs.lock().map_err(|e| {
            let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        let graph = match embedding_graphs.values().next() {
            Some(graph) => graph,
            None => {
                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::Operation(err_msg.into()));
            }
        };

        get_plugin_info_by_graph(graph)
    } else if running_mode.contains(RunningMode::TTS) {
        // tts: query any tts graph
        let tts_graphs = match TTS_GRAPHS.get() {
            Some(tts_graphs) => tts_graphs,
            None => {
                let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::Operation(err_msg.into()));
            }
        };

        let tts_graphs = tts_graphs.lock().map_err(|e| {
            let err_msg = format!("Fail to acquire the lock of `TTS_GRAPHS`. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        let graph = match tts_graphs.values().next() {
            Some(graph) => graph,
            None => {
                let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::Operation(err_msg.into()));
            }
        };

        get_plugin_info_by_graph(graph)
    } else {
        let err_msg = "RUNNING_MODE is not set";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        Err(LlamaCoreError::Operation(err_msg.into()))
    }
}
437
438fn get_plugin_info_by_graph<M: BaseMetadata + serde::Serialize + Clone + Default>(
439 graph: &Graph<M>,
440) -> Result<PluginInfo, LlamaCoreError> {
441 #[cfg(feature = "logging")]
442 debug!(target: "stdout", "Getting the plugin info by the graph named {}", graph.name());
443
444 let output_buffer = get_output_buffer(graph, PLUGIN_VERSION)?;
446 let metadata: serde_json::Value = serde_json::from_slice(&output_buffer[..]).map_err(|e| {
447 let err_msg = format!("Fail to deserialize the plugin metadata. {}", e);
448
449 #[cfg(feature = "logging")]
450 error!(target: "stdout", "{}", &err_msg);
451
452 LlamaCoreError::Operation(err_msg)
453 })?;
454
455 let plugin_build_number = match metadata.get("llama_build_number") {
457 Some(value) => match value.as_u64() {
458 Some(number) => number,
459 None => {
460 let err_msg = "Failed to convert the build number of the plugin to u64";
461
462 #[cfg(feature = "logging")]
463 error!(target: "stdout", "{}", err_msg);
464
465 return Err(LlamaCoreError::Operation(err_msg.into()));
466 }
467 },
468 None => {
469 let err_msg = "Metadata does not have the field `llama_build_number`.";
470
471 #[cfg(feature = "logging")]
472 error!(target: "stdout", "{}", err_msg);
473
474 return Err(LlamaCoreError::Operation(err_msg.into()));
475 }
476 };
477
478 let plugin_commit = match metadata.get("llama_commit") {
480 Some(value) => match value.as_str() {
481 Some(commit) => commit,
482 None => {
483 let err_msg = "Failed to convert the commit id of the plugin to string";
484
485 #[cfg(feature = "logging")]
486 error!(target: "stdout", "{}", err_msg);
487
488 return Err(LlamaCoreError::Operation(err_msg.into()));
489 }
490 },
491 None => {
492 let err_msg = "Metadata does not have the field `llama_commit`.";
493
494 #[cfg(feature = "logging")]
495 error!(target: "stdout", "{}", err_msg);
496
497 return Err(LlamaCoreError::Operation(err_msg.into()));
498 }
499 };
500
501 #[cfg(feature = "logging")]
502 debug!(target: "stdout", "Plugin info: b{}(commit {})", plugin_build_number, plugin_commit);
503
504 Ok(PluginInfo {
505 build_number: plugin_build_number,
506 commit_id: plugin_commit.to_string(),
507 })
508}
509
/// Build/version information of the underlying wasi-nn ggml plugin.
#[derive(Debug, Clone)]
pub struct PluginInfo {
    /// Plugin build number, taken from the `llama_build_number` metadata field.
    pub build_number: u64,
    /// Plugin commit id, taken from the `llama_commit` metadata field.
    pub commit_id: String,
}
impl std::fmt::Display for PluginInfo {
    // Rendered as e.g. `wasinn-ggml plugin: b1234(commit abcdef1)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "wasinn-ggml plugin: b{}(commit {})",
            self.build_number, self.commit_id
        )
    }
}
525
526pub fn running_mode() -> Result<RunningMode, LlamaCoreError> {
528 #[cfg(feature = "logging")]
529 debug!(target: "stdout", "Get the running mode.");
530
531 match RUNNING_MODE.get() {
532 Some(mode) => match mode.read() {
533 Ok(mode) => Ok(*mode),
534 Err(e) => {
535 let err_msg = format!("Fail to get the underlying value of `RUNNING_MODE`. {}", e);
536
537 #[cfg(feature = "logging")]
538 error!(target: "stdout", "{}", err_msg);
539
540 Err(LlamaCoreError::Operation(err_msg))
541 }
542 },
543 None => {
544 let err_msg = "Fail to get the underlying value of `RUNNING_MODE`.";
545
546 #[cfg(feature = "logging")]
547 error!(target: "stdout", "{}", err_msg);
548
549 Err(LlamaCoreError::Operation(err_msg.into()))
550 }
551 }
552}
553
/// Initialize the stable-diffusion context(s) from a single full model file.
///
/// Depending on `task`, builds the text-to-image context, the image-to-image
/// context, or both, and stores them in the `SD_TEXT_TO_IMAGE` /
/// `SD_IMAGE_TO_IMAGE` globals. Each global can only be set once.
///
/// # Arguments
///
/// * `model_file` - path to the full stable-diffusion model file.
/// * `lora_model_dir` - optional directory of LoRA models.
/// * `controlnet_path` - optional path to a ControlNet model.
/// * `controlnet_on_cpu` - keep ControlNet on CPU; only honored when a
///   non-empty `controlnet_path` is supplied.
/// * `clip_on_cpu` - keep the CLIP model on CPU.
/// * `vae_on_cpu` - keep the VAE on CPU.
/// * `n_threads` - number of threads for the backend.
/// * `task` - which context(s) to create.
///
/// # Errors
///
/// Returns `LlamaCoreError::InitContext` when the builder, context creation,
/// or global initialization fails.
#[allow(clippy::too_many_arguments)]
pub fn init_sd_context_with_full_model(
    model_file: impl AsRef<str>,
    lora_model_dir: Option<&str>,
    controlnet_path: Option<&str>,
    controlnet_on_cpu: bool,
    clip_on_cpu: bool,
    vae_on_cpu: bool,
    n_threads: i32,
    task: StableDiffusionTask,
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the stable diffusion context with the full model");

    // only respect `controlnet_on_cpu` when a controlnet model is actually provided
    let control_net_on_cpu = match controlnet_path {
        Some(path) if !path.is_empty() => controlnet_on_cpu,
        _ => false,
    };

    // text-to-image context
    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
        let sd = SDBuidler::new(Task::TextToImage, model_file.as_ref())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        // the builder was configured for text-to-image, so any other variant is an error
        let ctx = match ctx {
            Context::TextToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the text-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);

        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
    }

    // image-to-image context
    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
        let sd = SDBuidler::new(Task::ImageToImage, model_file.as_ref())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        // the builder was configured for image-to-image, so any other variant is an error
        let ctx = match ctx {
            Context::ImageToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the image-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);

        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
    }

    Ok(())
}
760
/// Initialize the stable-diffusion context(s) from a standalone diffusion
/// model plus separate VAE, CLIP-L, and T5-XXL model files.
///
/// Depending on `task`, builds the text-to-image context, the image-to-image
/// context, or both, and stores them in the `SD_TEXT_TO_IMAGE` /
/// `SD_IMAGE_TO_IMAGE` globals. Each global can only be set once.
///
/// # Arguments
///
/// * `model_file` - path to the standalone diffusion model file.
/// * `vae` - path to the VAE model file.
/// * `clip_l` - path to the CLIP-L text-encoder model file.
/// * `t5xxl` - path to the T5-XXL text-encoder model file.
/// * `lora_model_dir` - optional directory of LoRA models.
/// * `controlnet_path` - optional path to a ControlNet model.
/// * `controlnet_on_cpu` - keep ControlNet on CPU; only honored when a
///   non-empty `controlnet_path` is supplied.
/// * `clip_on_cpu` - keep the CLIP model on CPU.
/// * `vae_on_cpu` - keep the VAE on CPU.
/// * `n_threads` - number of threads for the backend.
/// * `task` - which context(s) to create.
///
/// # Errors
///
/// Returns `LlamaCoreError::InitContext` when the builder, context creation,
/// or global initialization fails.
#[allow(clippy::too_many_arguments)]
pub fn init_sd_context_with_standalone_model(
    model_file: impl AsRef<str>,
    vae: impl AsRef<str>,
    clip_l: impl AsRef<str>,
    t5xxl: impl AsRef<str>,
    lora_model_dir: Option<&str>,
    controlnet_path: Option<&str>,
    controlnet_on_cpu: bool,
    clip_on_cpu: bool,
    vae_on_cpu: bool,
    n_threads: i32,
    task: StableDiffusionTask,
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the stable diffusion context with the standalone diffusion model");

    // only respect `controlnet_on_cpu` when a controlnet model is actually provided
    let control_net_on_cpu = match controlnet_path {
        Some(path) if !path.is_empty() => controlnet_on_cpu,
        _ => false,
    };

    // text-to-image context
    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
        let sd = SDBuidler::new_with_standalone_model(Task::TextToImage, model_file.as_ref())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_vae_path(vae.as_ref())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_clip_l_path(clip_l.as_ref())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_t5xxl_path(t5xxl.as_ref())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        // the builder was configured for text-to-image, so any other variant is an error
        let ctx = match ctx {
            Context::TextToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the text-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);

        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
    }

    // image-to-image context
    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
        let sd = SDBuidler::new_with_standalone_model(Task::ImageToImage, model_file.as_ref())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_vae_path(vae.as_ref())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_clip_l_path(clip_l.as_ref())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_t5xxl_path(t5xxl.as_ref())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(|e| {
                let err_msg = format!(
                    "Failed to initialize the stable diffusion context. Reason: {}",
                    e
                );

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::InitContext(err_msg)
            })?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        // the builder was configured for image-to-image, so any other variant is an error
        let ctx = match ctx {
            Context::ImageToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the image-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);

        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
    }

    Ok(())
}
1048
/// Which stable-diffusion context(s) an `init_sd_context_*` call should create.
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
pub enum StableDiffusionTask {
    /// Create only the text-to-image context.
    TextToImage,
    /// Create only the image-to-image context.
    ImageToImage,
    /// Create both the text-to-image and image-to-image contexts.
    Full,
}
1059
/// Build a whisper graph from `whisper_metadata` and install it as the global
/// audio context.
///
/// Unlike the other contexts, the audio graph may be re-initialized: if
/// `AUDIO_GRAPH` is already set, the freshly built graph replaces the old one.
///
/// # Errors
///
/// Returns `LlamaCoreError::InitContext` when building the graph fails or the
/// existing graph's lock is poisoned.
#[cfg(feature = "whisper")]
pub fn init_whisper_context(whisper_metadata: &WhisperMetadata) -> Result<(), LlamaCoreError> {
    // build the whisper graph on CPU from the model file
    let graph = GraphBuilder::new(EngineType::Whisper)?
        .with_config(whisper_metadata.clone())?
        .use_cpu()
        .build_from_files([&whisper_metadata.model_path])?;

    if let Some(mutex_graph) = AUDIO_GRAPH.get() {
        // already initialized: swap in the new graph under the lock
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Re-initialize the audio context");

        let mut locked_graph = mutex_graph.lock().map_err(|e| {
            let err_msg = format!("Failed to lock the graph. Reason: {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;
        *locked_graph = graph;
    } else {
        // first initialization: publish the graph into the global cell
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Initialize the audio context");

        AUDIO_GRAPH.set(Mutex::new(graph)).map_err(|_| {
            let err_msg = "Failed to initialize the audio context. Reason: The `AUDIO_GRAPH` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "The audio context has been initialized");

    Ok(())
}
1106
1107pub fn init_piper_context(
1118 piper_metadata: &PiperMetadata,
1119 voice_model: impl AsRef<Path>,
1120 voice_config: impl AsRef<Path>,
1121 espeak_ng_data: impl AsRef<Path>,
1122) -> Result<(), LlamaCoreError> {
1123 #[cfg(feature = "logging")]
1124 info!(target: "stdout", "Initializing the piper context");
1125
1126 let config = serde_json::json!({
1127 "model": voice_model.as_ref().to_owned(),
1128 "config": voice_config.as_ref().to_owned(),
1129 "espeak_data": espeak_ng_data.as_ref().to_owned(),
1130 });
1131
1132 let graph = GraphBuilder::new(EngineType::Piper)?
1134 .with_config(piper_metadata.clone())?
1135 .use_cpu()
1136 .build_from_buffer([config.to_string()])?;
1137
1138 PIPER_GRAPH.set(Mutex::new(graph)).map_err(|_| {
1139 let err_msg = "Failed to initialize the piper context. Reason: The `PIPER_GRAPH` has already been initialized";
1140
1141 #[cfg(feature = "logging")]
1142 error!(target: "stdout", "{}", err_msg);
1143
1144 LlamaCoreError::InitContext(err_msg.into())
1145 })?;
1146
1147 #[cfg(feature = "logging")]
1148 info!(target: "stdout", "The piper context has been initialized");
1149
1150 Ok(())
1151}