1#![cfg_attr(docsrs, feature(doc_cfg))]
4
5#[cfg(feature = "logging")]
6#[macro_use]
7extern crate log;
8
9pub mod audio;
10pub mod chat;
11pub mod completions;
12pub mod embeddings;
13pub mod error;
14pub mod files;
15pub mod graph;
16pub mod images;
17pub mod metadata;
18pub mod models;
19pub mod tts;
20pub mod utils;
21
22pub use error::LlamaCoreError;
23pub use graph::{EngineType, Graph, GraphBuilder};
24#[cfg(feature = "whisper")]
25use metadata::whisper::WhisperMetadata;
26pub use metadata::{
27 ggml::{GgmlMetadata, GgmlTtsMetadata},
28 piper::PiperMetadata,
29 BaseMetadata,
30};
31use once_cell::sync::OnceCell;
32use std::{
33 collections::HashMap,
34 path::Path,
35 sync::{Mutex, RwLock},
36};
37use utils::{get_output_buffer, RunningMode};
38use wasmedge_stable_diffusion::*;
39
/// Chat model graphs keyed by graph name; populated once by `init_ggml_chat_context`.
pub(crate) static CHAT_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
/// Embedding model graphs keyed by graph name; populated once by
/// `init_ggml_embeddings_context`.
pub(crate) static EMBEDDING_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
/// TTS model graphs keyed by graph name; populated once by `init_ggml_tts_context`.
pub(crate) static TTS_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlTtsMetadata>>>> =
    OnceCell::new();
/// Running modes accumulated so far; each `init_ggml_*_context` function ORs its own
/// mode into this value.
pub(crate) static RUNNING_MODE: OnceCell<RwLock<RunningMode>> = OnceCell::new();
/// Stable diffusion text-to-image context; set by the `init_sd_context_*` functions.
pub(crate) static SD_TEXT_TO_IMAGE: OnceCell<Mutex<TextToImage>> = OnceCell::new();
/// Stable diffusion image-to-image context; set by the `init_sd_context_*` functions.
pub(crate) static SD_IMAGE_TO_IMAGE: OnceCell<Mutex<ImageToImage>> = OnceCell::new();
/// Whisper audio graph; set on first call of `init_whisper_context` and replaced on
/// subsequent calls.
#[cfg(feature = "whisper")]
pub(crate) static AUDIO_GRAPH: OnceCell<Mutex<Graph<WhisperMetadata>>> = OnceCell::new();
/// Piper graph; set once by `init_piper_context`.
pub(crate) static PIPER_GRAPH: OnceCell<Mutex<Graph<PiperMetadata>>> = OnceCell::new();

/// Maximum output buffer size (2^14 * 15 + 128).
/// NOTE(review): the unit is presumably bytes and it is presumably consumed by
/// `utils::get_output_buffer` — its use is not visible in this file; confirm.
pub(crate) const MAX_BUFFER_SIZE: usize = 2usize.pow(14) * 15 + 128;
/// Index of the output tensor.
/// NOTE(review): not referenced in this file; presumably passed to the graph output
/// getter in other modules — confirm.
pub(crate) const OUTPUT_TENSOR: usize = 0;
/// Output index passed to `get_output_buffer` to read the plugin's version metadata
/// (see `get_plugin_info_by_graph`).
const PLUGIN_VERSION: usize = 1;

/// Name of the archives directory.
/// NOTE(review): not referenced in this file; its exact use is defined elsewhere.
pub const ARCHIVES_DIR: &str = "archives";
67
68pub fn init_ggml_chat_context(metadata_for_chats: &[GgmlMetadata]) -> Result<(), LlamaCoreError> {
70 #[cfg(feature = "logging")]
71 info!(target: "stdout", "Initializing the core context");
72
73 if metadata_for_chats.is_empty() {
74 let err_msg = "The metadata for chat models is empty";
75
76 #[cfg(feature = "logging")]
77 error!(target: "stdout", "{err_msg}");
78
79 return Err(LlamaCoreError::InitContext(err_msg.into()));
80 }
81
82 let mut chat_graphs = HashMap::new();
83 for metadata in metadata_for_chats {
84 let graph = Graph::new(metadata.clone())?;
85
86 chat_graphs.insert(graph.name().to_string(), graph);
87 }
88
89 for model_name in chat_graphs.keys() {
92 chat::chat_completions::get_or_create_model_lock(model_name)?;
93
94 #[cfg(feature = "logging")]
95 info!(target: "stdout", "Pre-created stream lock for model: {}", model_name);
96 }
97
98 CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
99 let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";
100
101 #[cfg(feature = "logging")]
102 error!(target: "stdout", "{err_msg}");
103
104 LlamaCoreError::InitContext(err_msg.into())
105 })?;
106
107 let running_mode = RunningMode::CHAT;
109 match RUNNING_MODE.get() {
110 Some(mode) => {
111 let mut mode = mode.write().unwrap();
112 *mode |= running_mode;
113 }
114 None => {
115 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
116 let err_msg = "Failed to initialize the chat context. Reason: The `RUNNING_MODE` has already been initialized";
117
118 #[cfg(feature = "logging")]
119 error!(target: "stdout", "{err_msg}");
120
121 LlamaCoreError::InitContext(err_msg.into())
122 })?;
123 }
124 }
125
126 Ok(())
127}
128
129pub fn init_ggml_embeddings_context(
131 metadata_for_embeddings: &[GgmlMetadata],
132) -> Result<(), LlamaCoreError> {
133 #[cfg(feature = "logging")]
134 info!(target: "stdout", "Initializing the embeddings context");
135
136 if metadata_for_embeddings.is_empty() {
137 let err_msg = "The metadata for chat models is empty";
138
139 #[cfg(feature = "logging")]
140 error!(target: "stdout", "{err_msg}");
141
142 return Err(LlamaCoreError::InitContext(err_msg.into()));
143 }
144
145 let mut embedding_graphs = HashMap::new();
146 for metadata in metadata_for_embeddings {
147 let graph = Graph::new(metadata.clone())?;
148
149 embedding_graphs.insert(graph.name().to_string(), graph);
150 }
151 EMBEDDING_GRAPHS
152 .set(Mutex::new(embedding_graphs))
153 .map_err(|_| {
154 let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";
155
156 #[cfg(feature = "logging")]
157 error!(target: "stdout", "{err_msg}");
158
159 LlamaCoreError::InitContext(err_msg.into())
160 })?;
161
162 let running_mode = RunningMode::EMBEDDINGS;
164 match RUNNING_MODE.get() {
165 Some(mode) => {
166 let mut mode = mode.write().unwrap();
167 *mode |= running_mode;
168 }
169 None => {
170 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
171 let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
172
173 #[cfg(feature = "logging")]
174 error!(target: "stdout", "{err_msg}");
175
176 LlamaCoreError::InitContext(err_msg.into())
177 })?;
178 }
179 }
180
181 Ok(())
182}
183
184pub fn init_ggml_tts_context(metadata_for_tts: &[GgmlTtsMetadata]) -> Result<(), LlamaCoreError> {
186 #[cfg(feature = "logging")]
187 info!(target: "stdout", "Initializing the TTS context");
188
189 if metadata_for_tts.is_empty() {
190 let err_msg = "The metadata for tts models is empty";
191
192 #[cfg(feature = "logging")]
193 error!(target: "stdout", "{err_msg}");
194
195 return Err(LlamaCoreError::InitContext(err_msg.into()));
196 }
197
198 let mut tts_graphs = HashMap::new();
199 for metadata in metadata_for_tts {
200 let graph = Graph::new(metadata.clone())?;
201
202 tts_graphs.insert(graph.name().to_string(), graph);
203 }
204 TTS_GRAPHS.set(Mutex::new(tts_graphs)).map_err(|_| {
205 let err_msg = "Failed to initialize the core context. Reason: The `TTS_GRAPHS` has already been initialized";
206
207 #[cfg(feature = "logging")]
208 error!(target: "stdout", "{err_msg}");
209
210 LlamaCoreError::InitContext(err_msg.into())
211 })?;
212
213 let running_mode = RunningMode::TTS;
215 match RUNNING_MODE.get() {
216 Some(mode) => {
217 let mut mode = mode.write().unwrap();
218 *mode |= running_mode;
219 }
220 None => {
221 RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
222 let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
223
224 #[cfg(feature = "logging")]
225 error!(target: "stdout", "{err_msg}");
226
227 LlamaCoreError::InitContext(err_msg.into())
228 })?;
229 }
230 }
231
232 Ok(())
233}
234
235pub fn get_plugin_info() -> Result<PluginInfo, LlamaCoreError> {
239 #[cfg(feature = "logging")]
240 debug!(target: "stdout", "Getting the plugin info");
241
242 let running_mode = running_mode()?;
243
244 if running_mode.contains(RunningMode::CHAT) || running_mode.contains(RunningMode::RAG) {
245 let chat_graphs = match CHAT_GRAPHS.get() {
246 Some(chat_graphs) => chat_graphs,
247 None => {
248 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
249
250 #[cfg(feature = "logging")]
251 error!(target: "stdout", "{err_msg}");
252
253 return Err(LlamaCoreError::Operation(err_msg.into()));
254 }
255 };
256
257 let chat_graphs = chat_graphs.lock().map_err(|e| {
258 let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
259
260 #[cfg(feature = "logging")]
261 error!(target: "stdout", "{}", &err_msg);
262
263 LlamaCoreError::Operation(err_msg)
264 })?;
265
266 let graph = match chat_graphs.values().next() {
267 Some(graph) => graph,
268 None => {
269 let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
270
271 #[cfg(feature = "logging")]
272 error!(target: "stdout", "{err_msg}");
273
274 return Err(LlamaCoreError::Operation(err_msg.into()));
275 }
276 };
277
278 get_plugin_info_by_graph(graph)
279 } else if running_mode.contains(RunningMode::EMBEDDINGS) {
280 let embedding_graphs = match EMBEDDING_GRAPHS.get() {
281 Some(embedding_graphs) => embedding_graphs,
282 None => {
283 let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
284
285 #[cfg(feature = "logging")]
286 error!(target: "stdout", "{err_msg}");
287
288 return Err(LlamaCoreError::Operation(err_msg.into()));
289 }
290 };
291
292 let embedding_graphs = embedding_graphs.lock().map_err(|e| {
293 let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {e}");
294
295 #[cfg(feature = "logging")]
296 error!(target: "stdout", "{}", &err_msg);
297
298 LlamaCoreError::Operation(err_msg)
299 })?;
300
301 let graph = match embedding_graphs.values().next() {
302 Some(graph) => graph,
303 None => {
304 let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
305
306 #[cfg(feature = "logging")]
307 error!(target: "stdout", "{err_msg}");
308
309 return Err(LlamaCoreError::Operation(err_msg.into()));
310 }
311 };
312
313 get_plugin_info_by_graph(graph)
314 } else if running_mode.contains(RunningMode::TTS) {
315 let tts_graphs = match TTS_GRAPHS.get() {
316 Some(tts_graphs) => tts_graphs,
317 None => {
318 let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
319
320 #[cfg(feature = "logging")]
321 error!(target: "stdout", "{err_msg}");
322
323 return Err(LlamaCoreError::Operation(err_msg.into()));
324 }
325 };
326
327 let tts_graphs = tts_graphs.lock().map_err(|e| {
328 let err_msg = format!("Fail to acquire the lock of `TTS_GRAPHS`. {e}");
329
330 #[cfg(feature = "logging")]
331 error!(target: "stdout", "{}", &err_msg);
332
333 LlamaCoreError::Operation(err_msg)
334 })?;
335
336 let graph = match tts_graphs.values().next() {
337 Some(graph) => graph,
338 None => {
339 let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
340
341 #[cfg(feature = "logging")]
342 error!(target: "stdout", "{err_msg}");
343
344 return Err(LlamaCoreError::Operation(err_msg.into()));
345 }
346 };
347
348 get_plugin_info_by_graph(graph)
349 } else {
350 let err_msg = "RUNNING_MODE is not set";
351
352 #[cfg(feature = "logging")]
353 error!(target: "stdout", "{err_msg}");
354
355 Err(LlamaCoreError::Operation(err_msg.into()))
356 }
357}
358
359fn get_plugin_info_by_graph<M: BaseMetadata + serde::Serialize + Clone + Default>(
360 graph: &Graph<M>,
361) -> Result<PluginInfo, LlamaCoreError> {
362 #[cfg(feature = "logging")]
363 debug!(target: "stdout", "Getting the plugin info by the graph named {}", graph.name());
364
365 let output_buffer = get_output_buffer(graph, PLUGIN_VERSION)?;
367 let metadata: serde_json::Value = serde_json::from_slice(&output_buffer[..]).map_err(|e| {
368 let err_msg = format!("Fail to deserialize the plugin metadata. {e}");
369
370 #[cfg(feature = "logging")]
371 error!(target: "stdout", "{}", &err_msg);
372
373 LlamaCoreError::Operation(err_msg)
374 })?;
375
376 let plugin_build_number = match metadata.get("llama_build_number") {
378 Some(value) => match value.as_u64() {
379 Some(number) => number,
380 None => {
381 let err_msg = "Failed to convert the build number of the plugin to u64";
382
383 #[cfg(feature = "logging")]
384 error!(target: "stdout", "{err_msg}");
385
386 return Err(LlamaCoreError::Operation(err_msg.into()));
387 }
388 },
389 None => {
390 let err_msg = "Metadata does not have the field `llama_build_number`.";
391
392 #[cfg(feature = "logging")]
393 error!(target: "stdout", "{err_msg}");
394
395 return Err(LlamaCoreError::Operation(err_msg.into()));
396 }
397 };
398
399 let plugin_commit = match metadata.get("llama_commit") {
401 Some(value) => match value.as_str() {
402 Some(commit) => commit,
403 None => {
404 let err_msg = "Failed to convert the commit id of the plugin to string";
405
406 #[cfg(feature = "logging")]
407 error!(target: "stdout", "{err_msg}");
408
409 return Err(LlamaCoreError::Operation(err_msg.into()));
410 }
411 },
412 None => {
413 let err_msg = "Metadata does not have the field `llama_commit`.";
414
415 #[cfg(feature = "logging")]
416 error!(target: "stdout", "{err_msg}");
417
418 return Err(LlamaCoreError::Operation(err_msg.into()));
419 }
420 };
421
422 #[cfg(feature = "logging")]
423 debug!(target: "stdout", "Plugin info: b{plugin_build_number}(commit {plugin_commit})");
424
425 Ok(PluginInfo {
426 build_number: plugin_build_number,
427 commit_id: plugin_commit.to_string(),
428 })
429}
430
/// Version information reported by the wasi-nn ggml plugin.
#[derive(Debug, Clone)]
pub struct PluginInfo {
    /// Build number of the plugin (the `llama_build_number` metadata field).
    pub build_number: u64,
    /// Commit id of the plugin (the `llama_commit` metadata field).
    pub commit_id: String,
}
impl std::fmt::Display for PluginInfo {
    /// Formats as `wasinn-ggml plugin: b<build>(commit <commit>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let PluginInfo {
            build_number,
            commit_id,
        } = self;
        write!(f, "wasinn-ggml plugin: b{}(commit {})", build_number, commit_id)
    }
}
446
447pub fn running_mode() -> Result<RunningMode, LlamaCoreError> {
449 #[cfg(feature = "logging")]
450 debug!(target: "stdout", "Get the running mode.");
451
452 match RUNNING_MODE.get() {
453 Some(mode) => match mode.read() {
454 Ok(mode) => Ok(*mode),
455 Err(e) => {
456 let err_msg = format!("Fail to get the underlying value of `RUNNING_MODE`. {e}");
457
458 #[cfg(feature = "logging")]
459 error!(target: "stdout", "{err_msg}");
460
461 Err(LlamaCoreError::Operation(err_msg))
462 }
463 },
464 None => {
465 let err_msg = "Fail to get the underlying value of `RUNNING_MODE`.";
466
467 #[cfg(feature = "logging")]
468 error!(target: "stdout", "{err_msg}");
469
470 Err(LlamaCoreError::Operation(err_msg.into()))
471 }
472 }
473}
474
475#[allow(clippy::too_many_arguments)]
495pub fn init_sd_context_with_full_model(
496 model_file: impl AsRef<str>,
497 lora_model_dir: Option<&str>,
498 controlnet_path: Option<&str>,
499 controlnet_on_cpu: bool,
500 clip_on_cpu: bool,
501 vae_on_cpu: bool,
502 n_threads: i32,
503 task: StableDiffusionTask,
504) -> Result<(), LlamaCoreError> {
505 #[cfg(feature = "logging")]
506 info!(target: "stdout", "Initializing the stable diffusion context with the full model");
507
508 let control_net_on_cpu = match controlnet_path {
509 Some(path) if !path.is_empty() => controlnet_on_cpu,
510 _ => false,
511 };
512
513 if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
515 let sd = SDBuidler::new(Task::TextToImage, model_file.as_ref())
516 .map_err(|e| {
517 let err_msg =
518 format!("Failed to initialize the stable diffusion context. Reason: {e}");
519
520 #[cfg(feature = "logging")]
521 error!(target: "stdout", "{err_msg}");
522
523 LlamaCoreError::InitContext(err_msg)
524 })?
525 .with_lora_model_dir(lora_model_dir.unwrap_or_default())
526 .map_err(|e| {
527 let err_msg =
528 format!("Failed to initialize the stable diffusion context. Reason: {e}");
529
530 #[cfg(feature = "logging")]
531 error!(target: "stdout", "{err_msg}");
532
533 LlamaCoreError::InitContext(err_msg)
534 })?
535 .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
536 .map_err(|e| {
537 let err_msg =
538 format!("Failed to initialize the stable diffusion context. Reason: {e}");
539
540 #[cfg(feature = "logging")]
541 error!(target: "stdout", "{err_msg}");
542
543 LlamaCoreError::InitContext(err_msg)
544 })?
545 .clip_on_cpu(clip_on_cpu)
546 .vae_on_cpu(vae_on_cpu)
547 .with_n_threads(n_threads)
548 .build();
549
550 #[cfg(feature = "logging")]
551 info!(target: "stdout", "sd: {:?}", &sd);
552
553 let ctx = sd.create_context().map_err(|e| {
554 let err_msg = format!("Fail to create the context. {e}");
555
556 #[cfg(feature = "logging")]
557 error!(target: "stdout", "{}", &err_msg);
558
559 LlamaCoreError::InitContext(err_msg)
560 })?;
561
562 let ctx = match ctx {
563 Context::TextToImage(ctx) => ctx,
564 _ => {
565 let err_msg = "Fail to get the context for the text-to-image task";
566
567 #[cfg(feature = "logging")]
568 error!(target: "stdout", "{err_msg}");
569
570 return Err(LlamaCoreError::InitContext(err_msg.into()));
571 }
572 };
573
574 #[cfg(feature = "logging")]
575 info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
576
577 SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
578 let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";
579
580 #[cfg(feature = "logging")]
581 error!(target: "stdout", "{err_msg}");
582
583 LlamaCoreError::InitContext(err_msg.into())
584 })?;
585
586 #[cfg(feature = "logging")]
587 info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
588 }
589
590 if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
592 let sd = SDBuidler::new(Task::ImageToImage, model_file.as_ref())
593 .map_err(|e| {
594 let err_msg =
595 format!("Failed to initialize the stable diffusion context. Reason: {e}");
596
597 #[cfg(feature = "logging")]
598 error!(target: "stdout", "{err_msg}");
599
600 LlamaCoreError::InitContext(err_msg)
601 })?
602 .with_lora_model_dir(lora_model_dir.unwrap_or_default())
603 .map_err(|e| {
604 let err_msg =
605 format!("Failed to initialize the stable diffusion context. Reason: {e}");
606
607 #[cfg(feature = "logging")]
608 error!(target: "stdout", "{err_msg}");
609
610 LlamaCoreError::InitContext(err_msg)
611 })?
612 .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
613 .map_err(|e| {
614 let err_msg =
615 format!("Failed to initialize the stable diffusion context. Reason: {e}");
616
617 #[cfg(feature = "logging")]
618 error!(target: "stdout", "{err_msg}");
619
620 LlamaCoreError::InitContext(err_msg)
621 })?
622 .clip_on_cpu(clip_on_cpu)
623 .vae_on_cpu(vae_on_cpu)
624 .with_n_threads(n_threads)
625 .build();
626
627 #[cfg(feature = "logging")]
628 info!(target: "stdout", "sd: {:?}", &sd);
629
630 let ctx = sd.create_context().map_err(|e| {
631 let err_msg = format!("Fail to create the context. {e}");
632
633 #[cfg(feature = "logging")]
634 error!(target: "stdout", "{}", &err_msg);
635
636 LlamaCoreError::InitContext(err_msg)
637 })?;
638
639 let ctx = match ctx {
640 Context::ImageToImage(ctx) => ctx,
641 _ => {
642 let err_msg = "Fail to get the context for the image-to-image task";
643
644 #[cfg(feature = "logging")]
645 error!(target: "stdout", "{err_msg}");
646
647 return Err(LlamaCoreError::InitContext(err_msg.into()));
648 }
649 };
650
651 #[cfg(feature = "logging")]
652 info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
653
654 SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
655 let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";
656
657 #[cfg(feature = "logging")]
658 error!(target: "stdout", "{err_msg}");
659
660 LlamaCoreError::InitContext(err_msg.into())
661 })?;
662
663 #[cfg(feature = "logging")]
664 info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
665 }
666
667 Ok(())
668}
669
670#[allow(clippy::too_many_arguments)]
696pub fn init_sd_context_with_standalone_model(
697 model_file: impl AsRef<str>,
698 vae: impl AsRef<str>,
699 clip_l: impl AsRef<str>,
700 t5xxl: impl AsRef<str>,
701 lora_model_dir: Option<&str>,
702 controlnet_path: Option<&str>,
703 controlnet_on_cpu: bool,
704 clip_on_cpu: bool,
705 vae_on_cpu: bool,
706 n_threads: i32,
707 task: StableDiffusionTask,
708) -> Result<(), LlamaCoreError> {
709 #[cfg(feature = "logging")]
710 info!(target: "stdout", "Initializing the stable diffusion context with the standalone diffusion model");
711
712 let control_net_on_cpu = match controlnet_path {
713 Some(path) if !path.is_empty() => controlnet_on_cpu,
714 _ => false,
715 };
716
717 if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
719 let sd = SDBuidler::new_with_standalone_model(Task::TextToImage, model_file.as_ref())
720 .map_err(|e| {
721 let err_msg =
722 format!("Failed to initialize the stable diffusion context. Reason: {e}");
723
724 #[cfg(feature = "logging")]
725 error!(target: "stdout", "{err_msg}");
726
727 LlamaCoreError::InitContext(err_msg)
728 })?
729 .with_vae_path(vae.as_ref())
730 .map_err(|e| {
731 let err_msg =
732 format!("Failed to initialize the stable diffusion context. Reason: {e}");
733
734 #[cfg(feature = "logging")]
735 error!(target: "stdout", "{err_msg}");
736
737 LlamaCoreError::InitContext(err_msg)
738 })?
739 .with_clip_l_path(clip_l.as_ref())
740 .map_err(|e| {
741 let err_msg =
742 format!("Failed to initialize the stable diffusion context. Reason: {e}");
743
744 #[cfg(feature = "logging")]
745 error!(target: "stdout", "{err_msg}");
746
747 LlamaCoreError::InitContext(err_msg)
748 })?
749 .with_t5xxl_path(t5xxl.as_ref())
750 .map_err(|e| {
751 let err_msg =
752 format!("Failed to initialize the stable diffusion context. Reason: {e}");
753
754 #[cfg(feature = "logging")]
755 error!(target: "stdout", "{err_msg}");
756
757 LlamaCoreError::InitContext(err_msg)
758 })?
759 .with_lora_model_dir(lora_model_dir.unwrap_or_default())
760 .map_err(|e| {
761 let err_msg =
762 format!("Failed to initialize the stable diffusion context. Reason: {e}");
763
764 #[cfg(feature = "logging")]
765 error!(target: "stdout", "{err_msg}");
766
767 LlamaCoreError::InitContext(err_msg)
768 })?
769 .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
770 .map_err(|e| {
771 let err_msg =
772 format!("Failed to initialize the stable diffusion context. Reason: {e}");
773
774 #[cfg(feature = "logging")]
775 error!(target: "stdout", "{err_msg}");
776
777 LlamaCoreError::InitContext(err_msg)
778 })?
779 .clip_on_cpu(clip_on_cpu)
780 .vae_on_cpu(vae_on_cpu)
781 .with_n_threads(n_threads)
782 .build();
783
784 #[cfg(feature = "logging")]
785 info!(target: "stdout", "sd: {:?}", &sd);
786
787 let ctx = sd.create_context().map_err(|e| {
788 let err_msg = format!("Fail to create the context. {e}");
789
790 #[cfg(feature = "logging")]
791 error!(target: "stdout", "{}", &err_msg);
792
793 LlamaCoreError::InitContext(err_msg)
794 })?;
795
796 let ctx = match ctx {
797 Context::TextToImage(ctx) => ctx,
798 _ => {
799 let err_msg = "Fail to get the context for the text-to-image task";
800
801 #[cfg(feature = "logging")]
802 error!(target: "stdout", "{err_msg}");
803
804 return Err(LlamaCoreError::InitContext(err_msg.into()));
805 }
806 };
807
808 #[cfg(feature = "logging")]
809 info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
810
811 SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
812 let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";
813
814 #[cfg(feature = "logging")]
815 error!(target: "stdout", "{err_msg}");
816
817 LlamaCoreError::InitContext(err_msg.into())
818 })?;
819
820 #[cfg(feature = "logging")]
821 info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
822 }
823
824 if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
826 let sd = SDBuidler::new_with_standalone_model(Task::ImageToImage, model_file.as_ref())
827 .map_err(|e| {
828 let err_msg =
829 format!("Failed to initialize the stable diffusion context. Reason: {e}");
830
831 #[cfg(feature = "logging")]
832 error!(target: "stdout", "{err_msg}");
833
834 LlamaCoreError::InitContext(err_msg)
835 })?
836 .with_vae_path(vae.as_ref())
837 .map_err(|e| {
838 let err_msg =
839 format!("Failed to initialize the stable diffusion context. Reason: {e}");
840
841 #[cfg(feature = "logging")]
842 error!(target: "stdout", "{err_msg}");
843
844 LlamaCoreError::InitContext(err_msg)
845 })?
846 .with_clip_l_path(clip_l.as_ref())
847 .map_err(|e| {
848 let err_msg =
849 format!("Failed to initialize the stable diffusion context. Reason: {e}");
850
851 #[cfg(feature = "logging")]
852 error!(target: "stdout", "{err_msg}");
853
854 LlamaCoreError::InitContext(err_msg)
855 })?
856 .with_t5xxl_path(t5xxl.as_ref())
857 .map_err(|e| {
858 let err_msg =
859 format!("Failed to initialize the stable diffusion context. Reason: {e}");
860
861 #[cfg(feature = "logging")]
862 error!(target: "stdout", "{err_msg}");
863
864 LlamaCoreError::InitContext(err_msg)
865 })?
866 .with_lora_model_dir(lora_model_dir.unwrap_or_default())
867 .map_err(|e| {
868 let err_msg =
869 format!("Failed to initialize the stable diffusion context. Reason: {e}");
870
871 #[cfg(feature = "logging")]
872 error!(target: "stdout", "{err_msg}");
873
874 LlamaCoreError::InitContext(err_msg)
875 })?
876 .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
877 .map_err(|e| {
878 let err_msg =
879 format!("Failed to initialize the stable diffusion context. Reason: {e}");
880
881 #[cfg(feature = "logging")]
882 error!(target: "stdout", "{err_msg}");
883
884 LlamaCoreError::InitContext(err_msg)
885 })?
886 .clip_on_cpu(clip_on_cpu)
887 .vae_on_cpu(vae_on_cpu)
888 .with_n_threads(n_threads)
889 .build();
890
891 #[cfg(feature = "logging")]
892 info!(target: "stdout", "sd: {:?}", &sd);
893
894 let ctx = sd.create_context().map_err(|e| {
895 let err_msg = format!("Fail to create the context. {e}");
896
897 #[cfg(feature = "logging")]
898 error!(target: "stdout", "{}", &err_msg);
899
900 LlamaCoreError::InitContext(err_msg)
901 })?;
902
903 let ctx = match ctx {
904 Context::ImageToImage(ctx) => ctx,
905 _ => {
906 let err_msg = "Fail to get the context for the image-to-image task";
907
908 #[cfg(feature = "logging")]
909 error!(target: "stdout", "{err_msg}");
910
911 return Err(LlamaCoreError::InitContext(err_msg.into()));
912 }
913 };
914
915 #[cfg(feature = "logging")]
916 info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
917
918 SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
919 let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";
920
921 #[cfg(feature = "logging")]
922 error!(target: "stdout", "{err_msg}");
923
924 LlamaCoreError::InitContext(err_msg.into())
925 })?;
926
927 #[cfg(feature = "logging")]
928 info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
929 }
930
931 Ok(())
932}
933
/// Selects which stable diffusion context(s) the `init_sd_context_*` functions create.
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
pub enum StableDiffusionTask {
    /// Create only the text-to-image context (`SD_TEXT_TO_IMAGE`).
    TextToImage,
    /// Create only the image-to-image context (`SD_IMAGE_TO_IMAGE`).
    ImageToImage,
    /// Create both the text-to-image and the image-to-image contexts.
    Full,
}
944
/// Initializes (or re-initializes) the audio context backed by a whisper graph.
///
/// The graph is built on CPU from `whisper_metadata.model_path`, configured with a
/// clone of `whisper_metadata`. On the first call it is stored in `AUDIO_GRAPH`; on
/// later calls it replaces the graph already stored there.
///
/// # Errors
///
/// Returns [`LlamaCoreError::InitContext`] in these cases:
/// - the existing graph's lock is poisoned;
/// - `AUDIO_GRAPH` was concurrently initialized between the check and the store.
///
/// Errors from building the graph are propagated.
#[cfg(feature = "whisper")]
pub fn init_whisper_context(whisper_metadata: &WhisperMetadata) -> Result<(), LlamaCoreError> {
    let graph = GraphBuilder::new(EngineType::Whisper)?
        .with_config(whisper_metadata.clone())?
        .use_cpu()
        .build_from_files([&whisper_metadata.model_path])?;

    if let Some(mutex_graph) = AUDIO_GRAPH.get() {
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Re-initialize the audio context");

        let mut current = mutex_graph.lock().map_err(|e| {
            let err_msg = format!("Failed to lock the graph. Reason: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg)
        })?;
        *current = graph;
    } else {
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Initialize the audio context");

        AUDIO_GRAPH.set(Mutex::new(graph)).map_err(|_| {
            let err_msg = "Failed to initialize the audio context. Reason: The `AUDIO_GRAPH` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg.into())
        })?;
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "The audio context has been initialized");

    Ok(())
}
991
992pub fn init_piper_context(
1003 piper_metadata: &PiperMetadata,
1004 voice_model: impl AsRef<Path>,
1005 voice_config: impl AsRef<Path>,
1006 espeak_ng_data: impl AsRef<Path>,
1007) -> Result<(), LlamaCoreError> {
1008 #[cfg(feature = "logging")]
1009 info!(target: "stdout", "Initializing the piper context");
1010
1011 let config = serde_json::json!({
1012 "model": voice_model.as_ref().to_owned(),
1013 "config": voice_config.as_ref().to_owned(),
1014 "espeak_data": espeak_ng_data.as_ref().to_owned(),
1015 });
1016
1017 let graph = GraphBuilder::new(EngineType::Piper)?
1019 .with_config(piper_metadata.clone())?
1020 .use_cpu()
1021 .build_from_buffer([config.to_string()])?;
1022
1023 PIPER_GRAPH.set(Mutex::new(graph)).map_err(|_| {
1024 let err_msg = "Failed to initialize the piper context. Reason: The `PIPER_GRAPH` has already been initialized";
1025
1026 #[cfg(feature = "logging")]
1027 error!(target: "stdout", "{err_msg}");
1028
1029 LlamaCoreError::InitContext(err_msg.into())
1030 })?;
1031
1032 #[cfg(feature = "logging")]
1033 info!(target: "stdout", "The piper context has been initialized");
1034
1035 Ok(())
1036}