llama_core/
lib.rs

1//! Llama Core, abbreviated as `llama-core`, defines a set of APIs. Developers can utilize these APIs to build applications based on large models, such as chatbots, RAG, and more.
2
3#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
4
5#[cfg(feature = "logging")]
6#[macro_use]
7extern crate log;
8
9pub mod audio;
10pub mod chat;
11pub mod completions;
12pub mod embeddings;
13pub mod error;
14pub mod files;
15pub mod graph;
16pub mod images;
17pub mod metadata;
18pub mod models;
19#[cfg(feature = "rag")]
20#[cfg_attr(docsrs, doc(cfg(feature = "rag")))]
21pub mod rag;
22#[cfg(feature = "search")]
23#[cfg_attr(docsrs, doc(cfg(feature = "search")))]
24pub mod search;
25pub mod tts;
26pub mod utils;
27
28pub use error::LlamaCoreError;
29pub use graph::{EngineType, Graph, GraphBuilder};
30#[cfg(feature = "whisper")]
31use metadata::whisper::WhisperMetadata;
32pub use metadata::{
33    ggml::{GgmlMetadata, GgmlTtsMetadata},
34    piper::PiperMetadata,
35    BaseMetadata,
36};
37use once_cell::sync::OnceCell;
38use std::{
39    collections::HashMap,
40    path::Path,
41    sync::{Mutex, RwLock},
42};
43use utils::{get_output_buffer, RunningMode};
44use wasmedge_stable_diffusion::*;
45
// Chat-completion graphs, keyed by model name. Set once by the init functions.
pub(crate) static CHAT_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
// Embeddings graphs, keyed by model name.
pub(crate) static EMBEDDING_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
// Text-to-speech graphs, keyed by model name.
pub(crate) static TTS_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlTtsMetadata>>>> =
    OnceCell::new();
// Byte cache used while decoding UTF-8 output — presumably holds incomplete
// multi-byte sequences between reads; confirm against the chat streaming code.
pub(crate) static CACHED_UTF8_ENCODINGS: OnceCell<Mutex<Vec<u8>>> = OnceCell::new();
// The current running mode; flag values are OR-combined by the init functions.
pub(crate) static RUNNING_MODE: OnceCell<RwLock<RunningMode>> = OnceCell::new();
// stable diffusion context for the text-to-image task
pub(crate) static SD_TEXT_TO_IMAGE: OnceCell<Mutex<TextToImage>> = OnceCell::new();
// stable diffusion context for the image-to-image task
pub(crate) static SD_IMAGE_TO_IMAGE: OnceCell<Mutex<ImageToImage>> = OnceCell::new();
// whisper graph for the audio task
#[cfg(feature = "whisper")]
pub(crate) static AUDIO_GRAPH: OnceCell<Mutex<Graph<WhisperMetadata>>> = OnceCell::new();
// piper graph for the speech-synthesis task
pub(crate) static PIPER_GRAPH: OnceCell<Mutex<Graph<PiperMetadata>>> = OnceCell::new();

// Maximum output buffer size: 2^14 * 15 + 128 = 245,888 bytes.
pub(crate) const MAX_BUFFER_SIZE: usize = 2usize.pow(14) * 15 + 128;
// Index of the inference output tensor read back from a graph.
pub(crate) const OUTPUT_TENSOR: usize = 0;
// Output index passed to `get_output_buffer` to fetch the plugin metadata
// (see `get_plugin_info_by_graph`).
const PLUGIN_VERSION: usize = 1;

/// The directory for storing the archives in wasm virtual file system.
pub const ARCHIVES_DIR: &str = "archives";
75
76/// Initialize the ggml context
77pub fn init_ggml_chat_context(metadata_for_chats: &[GgmlMetadata]) -> Result<(), LlamaCoreError> {
78    #[cfg(feature = "logging")]
79    info!(target: "stdout", "Initializing the core context");
80
81    if metadata_for_chats.is_empty() {
82        let err_msg = "The metadata for chat models is empty";
83
84        #[cfg(feature = "logging")]
85        error!(target: "stdout", "{err_msg}");
86
87        return Err(LlamaCoreError::InitContext(err_msg.into()));
88    }
89
90    let mut chat_graphs = HashMap::new();
91    for metadata in metadata_for_chats {
92        let graph = Graph::new(metadata.clone())?;
93
94        chat_graphs.insert(graph.name().to_string(), graph);
95    }
96    CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
97            let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";
98
99            #[cfg(feature = "logging")]
100            error!(target: "stdout", "{err_msg}");
101
102            LlamaCoreError::InitContext(err_msg.into())
103        })?;
104
105    // set running mode
106    let running_mode = RunningMode::CHAT;
107    match RUNNING_MODE.get() {
108        Some(mode) => {
109            let mut mode = mode.write().unwrap();
110            *mode |= running_mode;
111        }
112        None => {
113            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
114                let err_msg = "Failed to initialize the chat context. Reason: The `RUNNING_MODE` has already been initialized";
115
116                #[cfg(feature = "logging")]
117                error!(target: "stdout", "{err_msg}");
118
119                LlamaCoreError::InitContext(err_msg.into())
120            })?;
121        }
122    }
123
124    Ok(())
125}
126
127/// Initialize the ggml context
128pub fn init_ggml_embeddings_context(
129    metadata_for_embeddings: &[GgmlMetadata],
130) -> Result<(), LlamaCoreError> {
131    #[cfg(feature = "logging")]
132    info!(target: "stdout", "Initializing the embeddings context");
133
134    if metadata_for_embeddings.is_empty() {
135        let err_msg = "The metadata for chat models is empty";
136
137        #[cfg(feature = "logging")]
138        error!(target: "stdout", "{err_msg}");
139
140        return Err(LlamaCoreError::InitContext(err_msg.into()));
141    }
142
143    let mut embedding_graphs = HashMap::new();
144    for metadata in metadata_for_embeddings {
145        let graph = Graph::new(metadata.clone())?;
146
147        embedding_graphs.insert(graph.name().to_string(), graph);
148    }
149    EMBEDDING_GRAPHS
150            .set(Mutex::new(embedding_graphs))
151            .map_err(|_| {
152                let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";
153
154                #[cfg(feature = "logging")]
155                error!(target: "stdout", "{err_msg}");
156
157                LlamaCoreError::InitContext(err_msg.into())
158            })?;
159
160    // set running mode
161    let running_mode = RunningMode::EMBEDDINGS;
162    match RUNNING_MODE.get() {
163        Some(mode) => {
164            let mut mode = mode.write().unwrap();
165            *mode |= running_mode;
166        }
167        None => {
168            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
169                let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
170
171                #[cfg(feature = "logging")]
172                error!(target: "stdout", "{err_msg}");
173
174                LlamaCoreError::InitContext(err_msg.into())
175            })?;
176        }
177    }
178
179    Ok(())
180}
181
/// Initialize the ggml context for RAG scenarios.
#[cfg(feature = "rag")]
pub fn init_ggml_rag_context(
    metadata_for_chats: &[GgmlMetadata],
    metadata_for_embeddings: &[GgmlMetadata],
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the core context for RAG scenarios");

    // chat models
    if metadata_for_chats.is_empty() {
        let err_msg = "The metadata for chat models is empty";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }
    // Build one graph per chat model, keyed by the model name.
    let chat_graphs = metadata_for_chats
        .iter()
        .map(|metadata| {
            let graph = Graph::new(metadata.clone())?;
            Ok((graph.name().to_string(), graph))
        })
        .collect::<Result<HashMap<_, _>, LlamaCoreError>>()?;
    CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
        let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        LlamaCoreError::InitContext(err_msg.into())
    })?;

    // embedding models
    if metadata_for_embeddings.is_empty() {
        let err_msg = "The metadata for embeddings is empty";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }
    // Build one graph per embeddings model, keyed by the model name.
    let embedding_graphs = metadata_for_embeddings
        .iter()
        .map(|metadata| {
            let graph = Graph::new(metadata.clone())?;
            Ok((graph.name().to_string(), graph))
        })
        .collect::<Result<HashMap<_, _>, LlamaCoreError>>()?;
    EMBEDDING_GRAPHS
        .set(Mutex::new(embedding_graphs))
        .map_err(|_| {
            let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg.into())
        })?;

    // Merge the RAG flag into the global running mode, creating it on first use.
    let running_mode = RunningMode::RAG;
    match RUNNING_MODE.get() {
        Some(lock) => *lock.write().unwrap() |= running_mode,
        None => {
            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
                let err_msg = "Failed to initialize the rag context. Reason: The `RUNNING_MODE` has already been initialized";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg.into())
            })?;
        }
    }

    Ok(())
}
262
263/// Initialize the ggml context for TTS scenarios.
264pub fn init_ggml_tts_context(metadata_for_tts: &[GgmlTtsMetadata]) -> Result<(), LlamaCoreError> {
265    #[cfg(feature = "logging")]
266    info!(target: "stdout", "Initializing the TTS context");
267
268    if metadata_for_tts.is_empty() {
269        let err_msg = "The metadata for tts models is empty";
270
271        #[cfg(feature = "logging")]
272        error!(target: "stdout", "{err_msg}");
273
274        return Err(LlamaCoreError::InitContext(err_msg.into()));
275    }
276
277    let mut tts_graphs = HashMap::new();
278    for metadata in metadata_for_tts {
279        let graph = Graph::new(metadata.clone())?;
280
281        tts_graphs.insert(graph.name().to_string(), graph);
282    }
283    TTS_GRAPHS.set(Mutex::new(tts_graphs)).map_err(|_| {
284        let err_msg = "Failed to initialize the core context. Reason: The `TTS_GRAPHS` has already been initialized";
285
286        #[cfg(feature = "logging")]
287        error!(target: "stdout", "{err_msg}");
288
289        LlamaCoreError::InitContext(err_msg.into())
290    })?;
291
292    // set running mode
293    let running_mode = RunningMode::TTS;
294    match RUNNING_MODE.get() {
295        Some(mode) => {
296            let mut mode = mode.write().unwrap();
297            *mode |= running_mode;
298        }
299        None => {
300            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
301                let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
302
303                #[cfg(feature = "logging")]
304                error!(target: "stdout", "{err_msg}");
305
306                LlamaCoreError::InitContext(err_msg.into())
307            })?;
308        }
309    }
310
311    Ok(())
312}
313
314/// Get the plugin info
315///
316/// Note that it is required to call `init_core_context` before calling this function.
317pub fn get_plugin_info() -> Result<PluginInfo, LlamaCoreError> {
318    #[cfg(feature = "logging")]
319    debug!(target: "stdout", "Getting the plugin info");
320
321    let running_mode = running_mode()?;
322
323    if running_mode.contains(RunningMode::CHAT) || running_mode.contains(RunningMode::RAG) {
324        let chat_graphs = match CHAT_GRAPHS.get() {
325            Some(chat_graphs) => chat_graphs,
326            None => {
327                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
328
329                #[cfg(feature = "logging")]
330                error!(target: "stdout", "{err_msg}");
331
332                return Err(LlamaCoreError::Operation(err_msg.into()));
333            }
334        };
335
336        let chat_graphs = chat_graphs.lock().map_err(|e| {
337            let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
338
339            #[cfg(feature = "logging")]
340            error!(target: "stdout", "{}", &err_msg);
341
342            LlamaCoreError::Operation(err_msg)
343        })?;
344
345        let graph = match chat_graphs.values().next() {
346            Some(graph) => graph,
347            None => {
348                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
349
350                #[cfg(feature = "logging")]
351                error!(target: "stdout", "{err_msg}");
352
353                return Err(LlamaCoreError::Operation(err_msg.into()));
354            }
355        };
356
357        get_plugin_info_by_graph(graph)
358    } else if running_mode.contains(RunningMode::EMBEDDINGS) {
359        let embedding_graphs = match EMBEDDING_GRAPHS.get() {
360            Some(embedding_graphs) => embedding_graphs,
361            None => {
362                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
363
364                #[cfg(feature = "logging")]
365                error!(target: "stdout", "{err_msg}");
366
367                return Err(LlamaCoreError::Operation(err_msg.into()));
368            }
369        };
370
371        let embedding_graphs = embedding_graphs.lock().map_err(|e| {
372            let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {e}");
373
374            #[cfg(feature = "logging")]
375            error!(target: "stdout", "{}", &err_msg);
376
377            LlamaCoreError::Operation(err_msg)
378        })?;
379
380        let graph = match embedding_graphs.values().next() {
381            Some(graph) => graph,
382            None => {
383                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
384
385                #[cfg(feature = "logging")]
386                error!(target: "stdout", "{err_msg}");
387
388                return Err(LlamaCoreError::Operation(err_msg.into()));
389            }
390        };
391
392        get_plugin_info_by_graph(graph)
393    } else if running_mode.contains(RunningMode::TTS) {
394        let tts_graphs = match TTS_GRAPHS.get() {
395            Some(tts_graphs) => tts_graphs,
396            None => {
397                let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
398
399                #[cfg(feature = "logging")]
400                error!(target: "stdout", "{err_msg}");
401
402                return Err(LlamaCoreError::Operation(err_msg.into()));
403            }
404        };
405
406        let tts_graphs = tts_graphs.lock().map_err(|e| {
407            let err_msg = format!("Fail to acquire the lock of `TTS_GRAPHS`. {e}");
408
409            #[cfg(feature = "logging")]
410            error!(target: "stdout", "{}", &err_msg);
411
412            LlamaCoreError::Operation(err_msg)
413        })?;
414
415        let graph = match tts_graphs.values().next() {
416            Some(graph) => graph,
417            None => {
418                let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
419
420                #[cfg(feature = "logging")]
421                error!(target: "stdout", "{err_msg}");
422
423                return Err(LlamaCoreError::Operation(err_msg.into()));
424            }
425        };
426
427        get_plugin_info_by_graph(graph)
428    } else {
429        let err_msg = "RUNNING_MODE is not set";
430
431        #[cfg(feature = "logging")]
432        error!(target: "stdout", "{err_msg}");
433
434        Err(LlamaCoreError::Operation(err_msg.into()))
435    }
436}
437
438fn get_plugin_info_by_graph<M: BaseMetadata + serde::Serialize + Clone + Default>(
439    graph: &Graph<M>,
440) -> Result<PluginInfo, LlamaCoreError> {
441    #[cfg(feature = "logging")]
442    debug!(target: "stdout", "Getting the plugin info by the graph named {}", graph.name());
443
444    // get the plugin metadata
445    let output_buffer = get_output_buffer(graph, PLUGIN_VERSION)?;
446    let metadata: serde_json::Value = serde_json::from_slice(&output_buffer[..]).map_err(|e| {
447        let err_msg = format!("Fail to deserialize the plugin metadata. {e}");
448
449        #[cfg(feature = "logging")]
450        error!(target: "stdout", "{}", &err_msg);
451
452        LlamaCoreError::Operation(err_msg)
453    })?;
454
455    // get build number of the plugin
456    let plugin_build_number = match metadata.get("llama_build_number") {
457        Some(value) => match value.as_u64() {
458            Some(number) => number,
459            None => {
460                let err_msg = "Failed to convert the build number of the plugin to u64";
461
462                #[cfg(feature = "logging")]
463                error!(target: "stdout", "{err_msg}");
464
465                return Err(LlamaCoreError::Operation(err_msg.into()));
466            }
467        },
468        None => {
469            let err_msg = "Metadata does not have the field `llama_build_number`.";
470
471            #[cfg(feature = "logging")]
472            error!(target: "stdout", "{err_msg}");
473
474            return Err(LlamaCoreError::Operation(err_msg.into()));
475        }
476    };
477
478    // get commit id of the plugin
479    let plugin_commit = match metadata.get("llama_commit") {
480        Some(value) => match value.as_str() {
481            Some(commit) => commit,
482            None => {
483                let err_msg = "Failed to convert the commit id of the plugin to string";
484
485                #[cfg(feature = "logging")]
486                error!(target: "stdout", "{err_msg}");
487
488                return Err(LlamaCoreError::Operation(err_msg.into()));
489            }
490        },
491        None => {
492            let err_msg = "Metadata does not have the field `llama_commit`.";
493
494            #[cfg(feature = "logging")]
495            error!(target: "stdout", "{err_msg}");
496
497            return Err(LlamaCoreError::Operation(err_msg.into()));
498        }
499    };
500
501    #[cfg(feature = "logging")]
502    debug!(target: "stdout", "Plugin info: b{plugin_build_number}(commit {plugin_commit})");
503
504    Ok(PluginInfo {
505        build_number: plugin_build_number,
506        commit_id: plugin_commit.to_string(),
507    })
508}
509
/// Version info of the `wasi-nn_ggml` plugin, including the build number and the commit id.
#[derive(Debug, Clone)]
pub struct PluginInfo {
    /// Build number of the plugin.
    pub build_number: u64,
    /// Commit id the plugin was built from.
    pub commit_id: String,
}

impl std::fmt::Display for PluginInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let PluginInfo {
            build_number,
            commit_id,
        } = self;
        write!(f, "wasinn-ggml plugin: b{build_number}(commit {commit_id})")
    }
}
525
526/// Return the current running mode.
527pub fn running_mode() -> Result<RunningMode, LlamaCoreError> {
528    #[cfg(feature = "logging")]
529    debug!(target: "stdout", "Get the running mode.");
530
531    match RUNNING_MODE.get() {
532        Some(mode) => match mode.read() {
533            Ok(mode) => Ok(*mode),
534            Err(e) => {
535                let err_msg = format!("Fail to get the underlying value of `RUNNING_MODE`. {e}");
536
537                #[cfg(feature = "logging")]
538                error!(target: "stdout", "{err_msg}");
539
540                Err(LlamaCoreError::Operation(err_msg))
541            }
542        },
543        None => {
544            let err_msg = "Fail to get the underlying value of `RUNNING_MODE`.";
545
546            #[cfg(feature = "logging")]
547            error!(target: "stdout", "{err_msg}");
548
549            Err(LlamaCoreError::Operation(err_msg.into()))
550        }
551    }
552}
553
/// Initialize the stable-diffusion context with the given full diffusion model
///
/// # Arguments
///
/// * `model_file` - Path to the stable diffusion model file.
///
/// * `lora_model_dir` - Path to the Lora model directory.
///
/// * `controlnet_path` - Path to the controlnet model file.
///
/// * `controlnet_on_cpu` - Whether to run the controlnet on CPU.
///
/// * `clip_on_cpu` - Whether to run the CLIP on CPU.
///
/// * `vae_on_cpu` - Whether to run the VAE on CPU.
///
/// * `n_threads` - Number of threads to use.
///
/// * `task` - The task type to perform.
///
/// # Errors
///
/// Returns `LlamaCoreError::InitContext` if building or creating a context
/// fails, or if the `SD_TEXT_TO_IMAGE` / `SD_IMAGE_TO_IMAGE` global for the
/// requested task has already been initialized.
#[allow(clippy::too_many_arguments)]
pub fn init_sd_context_with_full_model(
    model_file: impl AsRef<str>,
    lora_model_dir: Option<&str>,
    controlnet_path: Option<&str>,
    controlnet_on_cpu: bool,
    clip_on_cpu: bool,
    vae_on_cpu: bool,
    n_threads: i32,
    task: StableDiffusionTask,
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the stable diffusion context with the full model");

    // Only honor `controlnet_on_cpu` when a non-empty controlnet path was
    // actually provided; otherwise force it off.
    let control_net_on_cpu = match controlnet_path {
        Some(path) if !path.is_empty() => controlnet_on_cpu,
        _ => false,
    };

    // create the stable diffusion context for the text-to-image task
    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
        // NOTE(review): `SDBuidler` (sic) appears to come from the
        // `wasmedge_stable_diffusion` glob import — the misspelling is upstream.
        // Absent lora dir / controlnet path is passed through as "" via
        // `unwrap_or_default()`; presumably the builder treats empty as
        // "disabled" — confirm upstream.
        let sd = SDBuidler::new(Task::TextToImage, model_file.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        // The builder was configured for text-to-image; any other context
        // variant here is unexpected.
        let ctx = match ctx {
            Context::TextToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the text-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);

        // Publish the context; `set` fails only if already initialized.
        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
        let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{err_msg}");

        LlamaCoreError::InitContext(err_msg.into())
    })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
    }

    // create the stable diffusion context for the image-to-image task
    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
        let sd = SDBuidler::new(Task::ImageToImage, model_file.as_ref())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(|e| {
                let err_msg =
                    format!("Failed to initialize the stable diffusion context. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                LlamaCoreError::InitContext(err_msg)
            })?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        // The builder was configured for image-to-image; any other context
        // variant here is unexpected.
        let ctx = match ctx {
            Context::ImageToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the image-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);

        // Publish the context; `set` fails only if already initialized.
        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
    }

    Ok(())
}
748
749/// Initialize the stable-diffusion context with the given standalone diffusion model
750///
751/// # Arguments
752///
753/// * `model_file` - Path to the standalone diffusion model file.
754///
755/// * `vae` - Path to the VAE model file.
756///
757/// * `clip_l` - Path to the CLIP model file.
758///
759/// * `t5xxl` - Path to the T5-XXL model file.
760///
761/// * `lora_model_dir` - Path to the Lora model directory.
762///
763/// * `controlnet_path` - Path to the controlnet model file.
764///
765/// * `controlnet_on_cpu` - Whether to run the controlnet on CPU.
766///
767/// * `clip_on_cpu` - Whether to run the CLIP on CPU.
768///
769/// * `vae_on_cpu` - Whether to run the VAE on CPU.
770///
771/// * `n_threads` - Number of threads to use.
772///
773/// * `task` - The task type to perform.
774#[allow(clippy::too_many_arguments)]
775pub fn init_sd_context_with_standalone_model(
776    model_file: impl AsRef<str>,
777    vae: impl AsRef<str>,
778    clip_l: impl AsRef<str>,
779    t5xxl: impl AsRef<str>,
780    lora_model_dir: Option<&str>,
781    controlnet_path: Option<&str>,
782    controlnet_on_cpu: bool,
783    clip_on_cpu: bool,
784    vae_on_cpu: bool,
785    n_threads: i32,
786    task: StableDiffusionTask,
787) -> Result<(), LlamaCoreError> {
788    #[cfg(feature = "logging")]
789    info!(target: "stdout", "Initializing the stable diffusion context with the standalone diffusion model");
790
791    let control_net_on_cpu = match controlnet_path {
792        Some(path) if !path.is_empty() => controlnet_on_cpu,
793        _ => false,
794    };
795
796    // create the stable diffusion context for the text-to-image task
797    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
798        let sd = SDBuidler::new_with_standalone_model(Task::TextToImage, model_file.as_ref())
799            .map_err(|e| {
800                let err_msg =
801                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
802
803                #[cfg(feature = "logging")]
804                error!(target: "stdout", "{err_msg}");
805
806                LlamaCoreError::InitContext(err_msg)
807            })?
808            .with_vae_path(vae.as_ref())
809            .map_err(|e| {
810                let err_msg =
811                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
812
813                #[cfg(feature = "logging")]
814                error!(target: "stdout", "{err_msg}");
815
816                LlamaCoreError::InitContext(err_msg)
817            })?
818            .with_clip_l_path(clip_l.as_ref())
819            .map_err(|e| {
820                let err_msg =
821                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
822
823                #[cfg(feature = "logging")]
824                error!(target: "stdout", "{err_msg}");
825
826                LlamaCoreError::InitContext(err_msg)
827            })?
828            .with_t5xxl_path(t5xxl.as_ref())
829            .map_err(|e| {
830                let err_msg =
831                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
832
833                #[cfg(feature = "logging")]
834                error!(target: "stdout", "{err_msg}");
835
836                LlamaCoreError::InitContext(err_msg)
837            })?
838            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
839            .map_err(|e| {
840                let err_msg =
841                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
842
843                #[cfg(feature = "logging")]
844                error!(target: "stdout", "{err_msg}");
845
846                LlamaCoreError::InitContext(err_msg)
847            })?
848            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
849            .map_err(|e| {
850                let err_msg =
851                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
852
853                #[cfg(feature = "logging")]
854                error!(target: "stdout", "{err_msg}");
855
856                LlamaCoreError::InitContext(err_msg)
857            })?
858            .clip_on_cpu(clip_on_cpu)
859            .vae_on_cpu(vae_on_cpu)
860            .with_n_threads(n_threads)
861            .build();
862
863        #[cfg(feature = "logging")]
864        info!(target: "stdout", "sd: {:?}", &sd);
865
866        let ctx = sd.create_context().map_err(|e| {
867            let err_msg = format!("Fail to create the context. {e}");
868
869            #[cfg(feature = "logging")]
870            error!(target: "stdout", "{}", &err_msg);
871
872            LlamaCoreError::InitContext(err_msg)
873        })?;
874
875        let ctx = match ctx {
876            Context::TextToImage(ctx) => ctx,
877            _ => {
878                let err_msg = "Fail to get the context for the text-to-image task";
879
880                #[cfg(feature = "logging")]
881                error!(target: "stdout", "{err_msg}");
882
883                return Err(LlamaCoreError::InitContext(err_msg.into()));
884            }
885        };
886
887        #[cfg(feature = "logging")]
888        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
889
890        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
891            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";
892
893            #[cfg(feature = "logging")]
894            error!(target: "stdout", "{err_msg}");
895
896            LlamaCoreError::InitContext(err_msg.into())
897        })?;
898
899        #[cfg(feature = "logging")]
900        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
901    }
902
903    // create the stable diffusion context for the image-to-image task
904    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
905        let sd = SDBuidler::new_with_standalone_model(Task::ImageToImage, model_file.as_ref())
906            .map_err(|e| {
907                let err_msg =
908                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
909
910                #[cfg(feature = "logging")]
911                error!(target: "stdout", "{err_msg}");
912
913                LlamaCoreError::InitContext(err_msg)
914            })?
915            .with_vae_path(vae.as_ref())
916            .map_err(|e| {
917                let err_msg =
918                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
919
920                #[cfg(feature = "logging")]
921                error!(target: "stdout", "{err_msg}");
922
923                LlamaCoreError::InitContext(err_msg)
924            })?
925            .with_clip_l_path(clip_l.as_ref())
926            .map_err(|e| {
927                let err_msg =
928                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
929
930                #[cfg(feature = "logging")]
931                error!(target: "stdout", "{err_msg}");
932
933                LlamaCoreError::InitContext(err_msg)
934            })?
935            .with_t5xxl_path(t5xxl.as_ref())
936            .map_err(|e| {
937                let err_msg =
938                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
939
940                #[cfg(feature = "logging")]
941                error!(target: "stdout", "{err_msg}");
942
943                LlamaCoreError::InitContext(err_msg)
944            })?
945            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
946            .map_err(|e| {
947                let err_msg =
948                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
949
950                #[cfg(feature = "logging")]
951                error!(target: "stdout", "{err_msg}");
952
953                LlamaCoreError::InitContext(err_msg)
954            })?
955            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
956            .map_err(|e| {
957                let err_msg =
958                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
959
960                #[cfg(feature = "logging")]
961                error!(target: "stdout", "{err_msg}");
962
963                LlamaCoreError::InitContext(err_msg)
964            })?
965            .clip_on_cpu(clip_on_cpu)
966            .vae_on_cpu(vae_on_cpu)
967            .with_n_threads(n_threads)
968            .build();
969
970        #[cfg(feature = "logging")]
971        info!(target: "stdout", "sd: {:?}", &sd);
972
973        let ctx = sd.create_context().map_err(|e| {
974            let err_msg = format!("Fail to create the context. {e}");
975
976            #[cfg(feature = "logging")]
977            error!(target: "stdout", "{}", &err_msg);
978
979            LlamaCoreError::InitContext(err_msg)
980        })?;
981
982        let ctx = match ctx {
983            Context::ImageToImage(ctx) => ctx,
984            _ => {
985                let err_msg = "Fail to get the context for the image-to-image task";
986
987                #[cfg(feature = "logging")]
988                error!(target: "stdout", "{err_msg}");
989
990                return Err(LlamaCoreError::InitContext(err_msg.into()));
991            }
992        };
993
994        #[cfg(feature = "logging")]
995        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
996
997        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
998        let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";
999
1000        #[cfg(feature = "logging")]
1001        error!(target: "stdout", "{err_msg}");
1002
1003        LlamaCoreError::InitContext(err_msg.into())
1004    })?;
1005
1006        #[cfg(feature = "logging")]
1007        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
1008    }
1009
1010    Ok(())
1011}
1012
/// The task type of the stable diffusion context
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StableDiffusionTask {
    /// Create only the `text_to_image` context.
    TextToImage,
    /// Create only the `image_to_image` context.
    ImageToImage,
    /// Create both the `text_to_image` and `image_to_image` contexts.
    Full,
}
1023
/// Initialize the whisper context
#[cfg(feature = "whisper")]
pub fn init_whisper_context(whisper_metadata: &WhisperMetadata) -> Result<(), LlamaCoreError> {
    // Build a CPU-backed graph for the whisper model file described by the metadata.
    let graph = GraphBuilder::new(EngineType::Whisper)?
        .with_config(whisper_metadata.clone())?
        .use_cpu()
        .build_from_files([&whisper_metadata.model_path])?;

    if let Some(existing) = AUDIO_GRAPH.get() {
        // The global slot is already populated: swap the old graph for the new one.
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Re-initialize the audio context");

        let mut locked_graph = existing.lock().map_err(|e| {
            let err_msg = format!("Failed to lock the graph. Reason: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg)
        })?;
        *locked_graph = graph;
    } else {
        // First-time initialization: publish the graph into the global slot.
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Initialize the audio context");

        AUDIO_GRAPH.set(Mutex::new(graph)).map_err(|_| {
            let err_msg = "Failed to initialize the audio context. Reason: The `AUDIO_GRAPH` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg.into())
        })?;
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "The audio context has been initialized");

    Ok(())
}
1070
1071/// Initialize the piper context
1072///
1073/// # Arguments
1074///
1075/// * `voice_model` - Path to the voice model file.
1076///
1077/// * `voice_config` - Path to the voice config file.
1078///
1079/// * `espeak_ng_data` - Path to the espeak-ng data directory.
1080///
1081pub fn init_piper_context(
1082    piper_metadata: &PiperMetadata,
1083    voice_model: impl AsRef<Path>,
1084    voice_config: impl AsRef<Path>,
1085    espeak_ng_data: impl AsRef<Path>,
1086) -> Result<(), LlamaCoreError> {
1087    #[cfg(feature = "logging")]
1088    info!(target: "stdout", "Initializing the piper context");
1089
1090    let config = serde_json::json!({
1091        "model": voice_model.as_ref().to_owned(),
1092        "config": voice_config.as_ref().to_owned(),
1093        "espeak_data": espeak_ng_data.as_ref().to_owned(),
1094    });
1095
1096    // create and initialize the audio context
1097    let graph = GraphBuilder::new(EngineType::Piper)?
1098        .with_config(piper_metadata.clone())?
1099        .use_cpu()
1100        .build_from_buffer([config.to_string()])?;
1101
1102    PIPER_GRAPH.set(Mutex::new(graph)).map_err(|_| {
1103            let err_msg = "Failed to initialize the piper context. Reason: The `PIPER_GRAPH` has already been initialized";
1104
1105            #[cfg(feature = "logging")]
1106            error!(target: "stdout", "{err_msg}");
1107
1108            LlamaCoreError::InitContext(err_msg.into())
1109        })?;
1110
1111    #[cfg(feature = "logging")]
1112    info!(target: "stdout", "The piper context has been initialized");
1113
1114    Ok(())
1115}