// llama_core/lib.rs

1//! Llama Core, abbreviated as `llama-core`, defines a set of APIs. Developers can utilize these APIs to build applications based on large models, such as chatbots, RAG, and more.
2
3#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
4
5#[cfg(feature = "logging")]
6#[macro_use]
7extern crate log;
8
9pub mod audio;
10pub mod chat;
11pub mod completions;
12pub mod embeddings;
13pub mod error;
14pub mod files;
15pub mod graph;
16pub mod images;
17pub mod metadata;
18pub mod models;
19#[cfg(feature = "rag")]
20#[cfg_attr(docsrs, doc(cfg(feature = "rag")))]
21pub mod rag;
22#[cfg(feature = "search")]
23#[cfg_attr(docsrs, doc(cfg(feature = "search")))]
24pub mod search;
25pub mod tts;
26pub mod utils;
27
28pub use error::LlamaCoreError;
29pub use graph::{EngineType, Graph, GraphBuilder};
30#[cfg(feature = "whisper")]
31use metadata::whisper::WhisperMetadata;
32pub use metadata::{
33    ggml::{GgmlMetadata, GgmlTtsMetadata},
34    piper::PiperMetadata,
35    BaseMetadata,
36};
37use once_cell::sync::OnceCell;
38use std::{
39    collections::HashMap,
40    path::Path,
41    sync::{Mutex, RwLock},
42};
43use utils::{get_output_buffer, RunningMode};
44use wasmedge_stable_diffusion::*;
45
// Chat model graphs, keyed by model name. Set once by `init_ggml_chat_context`
// or `init_ggml_rag_context`.
pub(crate) static CHAT_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
// Embedding model graphs, keyed by model name. Set once by
// `init_ggml_embeddings_context` or `init_ggml_rag_context`.
pub(crate) static EMBEDDING_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
// TTS model graphs, keyed by model name. Set once by `init_ggml_tts_context`.
pub(crate) static TTS_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlTtsMetadata>>>> =
    OnceCell::new();
// Cache of bytes carried over while decoding possibly-incomplete UTF-8 sequences.
pub(crate) static CACHED_UTF8_ENCODINGS: OnceCell<Mutex<Vec<u8>>> = OnceCell::new();
// Current running mode. The init functions OR additional mode flags into this value.
pub(crate) static RUNNING_MODE: OnceCell<RwLock<RunningMode>> = OnceCell::new();
// Stable diffusion context for the text-to-image task.
pub(crate) static SD_TEXT_TO_IMAGE: OnceCell<Mutex<TextToImage>> = OnceCell::new();
// Stable diffusion context for the image-to-image task.
pub(crate) static SD_IMAGE_TO_IMAGE: OnceCell<Mutex<ImageToImage>> = OnceCell::new();
// Graph for the audio task (whisper).
#[cfg(feature = "whisper")]
pub(crate) static AUDIO_GRAPH: OnceCell<Mutex<Graph<WhisperMetadata>>> = OnceCell::new();
// Graph for the piper text-to-speech task.
pub(crate) static PIPER_GRAPH: OnceCell<Mutex<Graph<PiperMetadata>>> = OnceCell::new();

// Maximum output buffer size in bytes: 2^14 * 15 + 128 = 245_888.
// NOTE(review): the sizing rationale is not visible here — presumably the largest
// expected single inference output; confirm against the graph compute code.
pub(crate) const MAX_BUFFER_SIZE: usize = 2usize.pow(14) * 15 + 128;
// Index of the output tensor when fetching inference results.
pub(crate) const OUTPUT_TENSOR: usize = 0;
// Buffer index passed to `get_output_buffer` to read the plugin's metadata/version info.
const PLUGIN_VERSION: usize = 1;

/// The directory for storing the archives in wasm virtual file system.
pub const ARCHIVES_DIR: &str = "archives";
75
76/// Initialize the ggml context
77pub fn init_ggml_chat_context(metadata_for_chats: &[GgmlMetadata]) -> Result<(), LlamaCoreError> {
78    #[cfg(feature = "logging")]
79    info!(target: "stdout", "Initializing the core context");
80
81    if metadata_for_chats.is_empty() {
82        let err_msg = "The metadata for chat models is empty";
83
84        #[cfg(feature = "logging")]
85        error!(target: "stdout", "{}", err_msg);
86
87        return Err(LlamaCoreError::InitContext(err_msg.into()));
88    }
89
90    let mut chat_graphs = HashMap::new();
91    for metadata in metadata_for_chats {
92        let graph = Graph::new(metadata.clone())?;
93
94        chat_graphs.insert(graph.name().to_string(), graph);
95    }
96    CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
97            let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";
98
99            #[cfg(feature = "logging")]
100            error!(target: "stdout", "{}", err_msg);
101
102            LlamaCoreError::InitContext(err_msg.into())
103        })?;
104
105    // set running mode
106    let running_mode = RunningMode::CHAT;
107    match RUNNING_MODE.get() {
108        Some(mode) => {
109            let mut mode = mode.write().unwrap();
110            *mode |= running_mode;
111        }
112        None => {
113            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
114                let err_msg = "Failed to initialize the chat context. Reason: The `RUNNING_MODE` has already been initialized";
115
116                #[cfg(feature = "logging")]
117                error!(target: "stdout", "{}", err_msg);
118
119                LlamaCoreError::InitContext(err_msg.into())
120            })?;
121        }
122    }
123
124    Ok(())
125}
126
127/// Initialize the ggml context
128pub fn init_ggml_embeddings_context(
129    metadata_for_embeddings: &[GgmlMetadata],
130) -> Result<(), LlamaCoreError> {
131    #[cfg(feature = "logging")]
132    info!(target: "stdout", "Initializing the embeddings context");
133
134    if metadata_for_embeddings.is_empty() {
135        let err_msg = "The metadata for chat models is empty";
136
137        #[cfg(feature = "logging")]
138        error!(target: "stdout", "{}", err_msg);
139
140        return Err(LlamaCoreError::InitContext(err_msg.into()));
141    }
142
143    let mut embedding_graphs = HashMap::new();
144    for metadata in metadata_for_embeddings {
145        let graph = Graph::new(metadata.clone())?;
146
147        embedding_graphs.insert(graph.name().to_string(), graph);
148    }
149    EMBEDDING_GRAPHS
150            .set(Mutex::new(embedding_graphs))
151            .map_err(|_| {
152                let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";
153
154                #[cfg(feature = "logging")]
155                error!(target: "stdout", "{}", err_msg);
156
157                LlamaCoreError::InitContext(err_msg.into())
158            })?;
159
160    // set running mode
161    let running_mode = RunningMode::EMBEDDINGS;
162    match RUNNING_MODE.get() {
163        Some(mode) => {
164            let mut mode = mode.write().unwrap();
165            *mode |= running_mode;
166        }
167        None => {
168            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
169                let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
170
171                #[cfg(feature = "logging")]
172                error!(target: "stdout", "{}", err_msg);
173
174                LlamaCoreError::InitContext(err_msg.into())
175            })?;
176        }
177    }
178
179    Ok(())
180}
181
/// Initialize the ggml context for RAG scenarios.
///
/// Builds both the chat graphs (`CHAT_GRAPHS`) and the embedding graphs
/// (`EMBEDDING_GRAPHS`), then merges [`RunningMode::RAG`] into the global
/// running mode.
#[cfg(feature = "rag")]
pub fn init_ggml_rag_context(
    metadata_for_chats: &[GgmlMetadata],
    metadata_for_embeddings: &[GgmlMetadata],
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the core context for RAG scenarios");

    // chat models
    if metadata_for_chats.is_empty() {
        let err_msg = "The metadata for chat models is empty";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }
    let chat_graphs = metadata_for_chats
        .iter()
        .map(|metadata| {
            let graph = Graph::new(metadata.clone())?;
            Ok((graph.name().to_string(), graph))
        })
        .collect::<Result<HashMap<_, _>, LlamaCoreError>>()?;
    if CHAT_GRAPHS.set(Mutex::new(chat_graphs)).is_err() {
        let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }

    // embedding models
    if metadata_for_embeddings.is_empty() {
        let err_msg = "The metadata for embeddings is empty";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }
    let embedding_graphs = metadata_for_embeddings
        .iter()
        .map(|metadata| {
            let graph = Graph::new(metadata.clone())?;
            Ok((graph.name().to_string(), graph))
        })
        .collect::<Result<HashMap<_, _>, LlamaCoreError>>()?;
    if EMBEDDING_GRAPHS.set(Mutex::new(embedding_graphs)).is_err() {
        let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }

    // Merge the RAG flag into the global running mode, creating it on first use.
    let running_mode = RunningMode::RAG;
    if let Some(mode) = RUNNING_MODE.get() {
        *mode.write().unwrap() |= running_mode;
    } else if RUNNING_MODE.set(RwLock::new(running_mode)).is_err() {
        let err_msg = "Failed to initialize the rag context. Reason: The `RUNNING_MODE` has already been initialized";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }

    Ok(())
}
262
263/// Initialize the ggml context for TTS scenarios.
264pub fn init_ggml_tts_context(metadata_for_tts: &[GgmlTtsMetadata]) -> Result<(), LlamaCoreError> {
265    #[cfg(feature = "logging")]
266    info!(target: "stdout", "Initializing the TTS context");
267
268    if metadata_for_tts.is_empty() {
269        let err_msg = "The metadata for tts models is empty";
270
271        #[cfg(feature = "logging")]
272        error!(target: "stdout", "{}", err_msg);
273
274        return Err(LlamaCoreError::InitContext(err_msg.into()));
275    }
276
277    let mut tts_graphs = HashMap::new();
278    for metadata in metadata_for_tts {
279        let graph = Graph::new(metadata.clone())?;
280
281        tts_graphs.insert(graph.name().to_string(), graph);
282    }
283    TTS_GRAPHS.set(Mutex::new(tts_graphs)).map_err(|_| {
284        let err_msg = "Failed to initialize the core context. Reason: The `TTS_GRAPHS` has already been initialized";
285
286        #[cfg(feature = "logging")]
287        error!(target: "stdout", "{}", err_msg);
288
289        LlamaCoreError::InitContext(err_msg.into())
290    })?;
291
292    // set running mode
293    let running_mode = RunningMode::TTS;
294    match RUNNING_MODE.get() {
295        Some(mode) => {
296            let mut mode = mode.write().unwrap();
297            *mode |= running_mode;
298        }
299        None => {
300            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
301                let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
302
303                #[cfg(feature = "logging")]
304                error!(target: "stdout", "{}", err_msg);
305
306                LlamaCoreError::InitContext(err_msg.into())
307            })?;
308        }
309    }
310
311    Ok(())
312}
313
314/// Get the plugin info
315///
316/// Note that it is required to call `init_core_context` before calling this function.
317pub fn get_plugin_info() -> Result<PluginInfo, LlamaCoreError> {
318    #[cfg(feature = "logging")]
319    debug!(target: "stdout", "Getting the plugin info");
320
321    let running_mode = running_mode()?;
322
323    if running_mode.contains(RunningMode::CHAT) || running_mode.contains(RunningMode::RAG) {
324        let chat_graphs = match CHAT_GRAPHS.get() {
325            Some(chat_graphs) => chat_graphs,
326            None => {
327                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
328
329                #[cfg(feature = "logging")]
330                error!(target: "stdout", "{}", err_msg);
331
332                return Err(LlamaCoreError::Operation(err_msg.into()));
333            }
334        };
335
336        let chat_graphs = chat_graphs.lock().map_err(|e| {
337            let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);
338
339            #[cfg(feature = "logging")]
340            error!(target: "stdout", "{}", &err_msg);
341
342            LlamaCoreError::Operation(err_msg)
343        })?;
344
345        let graph = match chat_graphs.values().next() {
346            Some(graph) => graph,
347            None => {
348                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
349
350                #[cfg(feature = "logging")]
351                error!(target: "stdout", "{}", err_msg);
352
353                return Err(LlamaCoreError::Operation(err_msg.into()));
354            }
355        };
356
357        get_plugin_info_by_graph(graph)
358    } else if running_mode.contains(RunningMode::EMBEDDINGS) {
359        let embedding_graphs = match EMBEDDING_GRAPHS.get() {
360            Some(embedding_graphs) => embedding_graphs,
361            None => {
362                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
363
364                #[cfg(feature = "logging")]
365                error!(target: "stdout", "{}", err_msg);
366
367                return Err(LlamaCoreError::Operation(err_msg.into()));
368            }
369        };
370
371        let embedding_graphs = embedding_graphs.lock().map_err(|e| {
372            let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {}", e);
373
374            #[cfg(feature = "logging")]
375            error!(target: "stdout", "{}", &err_msg);
376
377            LlamaCoreError::Operation(err_msg)
378        })?;
379
380        let graph = match embedding_graphs.values().next() {
381            Some(graph) => graph,
382            None => {
383                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
384
385                #[cfg(feature = "logging")]
386                error!(target: "stdout", "{}", err_msg);
387
388                return Err(LlamaCoreError::Operation(err_msg.into()));
389            }
390        };
391
392        get_plugin_info_by_graph(graph)
393    } else if running_mode.contains(RunningMode::TTS) {
394        let tts_graphs = match TTS_GRAPHS.get() {
395            Some(tts_graphs) => tts_graphs,
396            None => {
397                let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
398
399                #[cfg(feature = "logging")]
400                error!(target: "stdout", "{}", err_msg);
401
402                return Err(LlamaCoreError::Operation(err_msg.into()));
403            }
404        };
405
406        let tts_graphs = tts_graphs.lock().map_err(|e| {
407            let err_msg = format!("Fail to acquire the lock of `TTS_GRAPHS`. {}", e);
408
409            #[cfg(feature = "logging")]
410            error!(target: "stdout", "{}", &err_msg);
411
412            LlamaCoreError::Operation(err_msg)
413        })?;
414
415        let graph = match tts_graphs.values().next() {
416            Some(graph) => graph,
417            None => {
418                let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
419
420                #[cfg(feature = "logging")]
421                error!(target: "stdout", "{}", err_msg);
422
423                return Err(LlamaCoreError::Operation(err_msg.into()));
424            }
425        };
426
427        get_plugin_info_by_graph(graph)
428    } else {
429        let err_msg = "RUNNING_MODE is not set";
430
431        #[cfg(feature = "logging")]
432        error!(target: "stdout", "{}", err_msg);
433
434        Err(LlamaCoreError::Operation(err_msg.into()))
435    }
436}
437
438fn get_plugin_info_by_graph<M: BaseMetadata + serde::Serialize + Clone + Default>(
439    graph: &Graph<M>,
440) -> Result<PluginInfo, LlamaCoreError> {
441    #[cfg(feature = "logging")]
442    debug!(target: "stdout", "Getting the plugin info by the graph named {}", graph.name());
443
444    // get the plugin metadata
445    let output_buffer = get_output_buffer(graph, PLUGIN_VERSION)?;
446    let metadata: serde_json::Value = serde_json::from_slice(&output_buffer[..]).map_err(|e| {
447        let err_msg = format!("Fail to deserialize the plugin metadata. {}", e);
448
449        #[cfg(feature = "logging")]
450        error!(target: "stdout", "{}", &err_msg);
451
452        LlamaCoreError::Operation(err_msg)
453    })?;
454
455    // get build number of the plugin
456    let plugin_build_number = match metadata.get("llama_build_number") {
457        Some(value) => match value.as_u64() {
458            Some(number) => number,
459            None => {
460                let err_msg = "Failed to convert the build number of the plugin to u64";
461
462                #[cfg(feature = "logging")]
463                error!(target: "stdout", "{}", err_msg);
464
465                return Err(LlamaCoreError::Operation(err_msg.into()));
466            }
467        },
468        None => {
469            let err_msg = "Metadata does not have the field `llama_build_number`.";
470
471            #[cfg(feature = "logging")]
472            error!(target: "stdout", "{}", err_msg);
473
474            return Err(LlamaCoreError::Operation(err_msg.into()));
475        }
476    };
477
478    // get commit id of the plugin
479    let plugin_commit = match metadata.get("llama_commit") {
480        Some(value) => match value.as_str() {
481            Some(commit) => commit,
482            None => {
483                let err_msg = "Failed to convert the commit id of the plugin to string";
484
485                #[cfg(feature = "logging")]
486                error!(target: "stdout", "{}", err_msg);
487
488                return Err(LlamaCoreError::Operation(err_msg.into()));
489            }
490        },
491        None => {
492            let err_msg = "Metadata does not have the field `llama_commit`.";
493
494            #[cfg(feature = "logging")]
495            error!(target: "stdout", "{}", err_msg);
496
497            return Err(LlamaCoreError::Operation(err_msg.into()));
498        }
499    };
500
501    #[cfg(feature = "logging")]
502    debug!(target: "stdout", "Plugin info: b{}(commit {})", plugin_build_number, plugin_commit);
503
504    Ok(PluginInfo {
505        build_number: plugin_build_number,
506        commit_id: plugin_commit.to_string(),
507    })
508}
509
/// Version info of the `wasi-nn_ggml` plugin, including the build number and the commit id.
#[derive(Debug, Clone)]
pub struct PluginInfo {
    pub build_number: u64,
    pub commit_id: String,
}

impl std::fmt::Display for PluginInfo {
    /// Renders as `wasinn-ggml plugin: b<build>(commit <id>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "wasinn-ggml plugin: b{}(commit {})",
            self.build_number, self.commit_id
        ))
    }
}
525
526/// Return the current running mode.
527pub fn running_mode() -> Result<RunningMode, LlamaCoreError> {
528    #[cfg(feature = "logging")]
529    debug!(target: "stdout", "Get the running mode.");
530
531    match RUNNING_MODE.get() {
532        Some(mode) => match mode.read() {
533            Ok(mode) => Ok(*mode),
534            Err(e) => {
535                let err_msg = format!("Fail to get the underlying value of `RUNNING_MODE`. {}", e);
536
537                #[cfg(feature = "logging")]
538                error!(target: "stdout", "{}", err_msg);
539
540                Err(LlamaCoreError::Operation(err_msg))
541            }
542        },
543        None => {
544            let err_msg = "Fail to get the underlying value of `RUNNING_MODE`.";
545
546            #[cfg(feature = "logging")]
547            error!(target: "stdout", "{}", err_msg);
548
549            Err(LlamaCoreError::Operation(err_msg.into()))
550        }
551    }
552}
553
554/// Initialize the stable-diffusion context with the given full diffusion model
555///
556/// # Arguments
557///
558/// * `model_file` - Path to the stable diffusion model file.
559///
560/// * `lora_model_dir` - Path to the Lora model directory.
561///
562/// * `controlnet_path` - Path to the controlnet model file.
563///
564/// * `controlnet_on_cpu` - Whether to run the controlnet on CPU.
565///
566/// * `clip_on_cpu` - Whether to run the CLIP on CPU.
567///
568/// * `vae_on_cpu` - Whether to run the VAE on CPU.
569///
570/// * `n_threads` - Number of threads to use.
571///
572/// * `task` - The task type to perform.
573#[allow(clippy::too_many_arguments)]
574pub fn init_sd_context_with_full_model(
575    model_file: impl AsRef<str>,
576    lora_model_dir: Option<&str>,
577    controlnet_path: Option<&str>,
578    controlnet_on_cpu: bool,
579    clip_on_cpu: bool,
580    vae_on_cpu: bool,
581    n_threads: i32,
582    task: StableDiffusionTask,
583) -> Result<(), LlamaCoreError> {
584    #[cfg(feature = "logging")]
585    info!(target: "stdout", "Initializing the stable diffusion context with the full model");
586
587    let control_net_on_cpu = match controlnet_path {
588        Some(path) if !path.is_empty() => controlnet_on_cpu,
589        _ => false,
590    };
591
592    // create the stable diffusion context for the text-to-image task
593    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
594        let sd = SDBuidler::new(Task::TextToImage, model_file.as_ref())
595            .map_err(|e| {
596                let err_msg = format!(
597                    "Failed to initialize the stable diffusion context. Reason: {}",
598                    e
599                );
600
601                #[cfg(feature = "logging")]
602                error!(target: "stdout", "{}", err_msg);
603
604                LlamaCoreError::InitContext(err_msg)
605            })?
606            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
607            .map_err(|e| {
608                let err_msg = format!(
609                    "Failed to initialize the stable diffusion context. Reason: {}",
610                    e
611                );
612
613                #[cfg(feature = "logging")]
614                error!(target: "stdout", "{}", err_msg);
615
616                LlamaCoreError::InitContext(err_msg)
617            })?
618            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
619            .map_err(|e| {
620                let err_msg = format!(
621                    "Failed to initialize the stable diffusion context. Reason: {}",
622                    e
623                );
624
625                #[cfg(feature = "logging")]
626                error!(target: "stdout", "{}", err_msg);
627
628                LlamaCoreError::InitContext(err_msg)
629            })?
630            .clip_on_cpu(clip_on_cpu)
631            .vae_on_cpu(vae_on_cpu)
632            .with_n_threads(n_threads)
633            .build();
634
635        #[cfg(feature = "logging")]
636        info!(target: "stdout", "sd: {:?}", &sd);
637
638        let ctx = sd.create_context().map_err(|e| {
639            let err_msg = format!("Fail to create the context. {}", e);
640
641            #[cfg(feature = "logging")]
642            error!(target: "stdout", "{}", &err_msg);
643
644            LlamaCoreError::InitContext(err_msg)
645        })?;
646
647        let ctx = match ctx {
648            Context::TextToImage(ctx) => ctx,
649            _ => {
650                let err_msg = "Fail to get the context for the text-to-image task";
651
652                #[cfg(feature = "logging")]
653                error!(target: "stdout", "{}", err_msg);
654
655                return Err(LlamaCoreError::InitContext(err_msg.into()));
656            }
657        };
658
659        #[cfg(feature = "logging")]
660        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
661
662        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
663        let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";
664
665        #[cfg(feature = "logging")]
666        error!(target: "stdout", "{}", err_msg);
667
668        LlamaCoreError::InitContext(err_msg.into())
669    })?;
670
671        #[cfg(feature = "logging")]
672        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
673    }
674
675    // create the stable diffusion context for the image-to-image task
676    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
677        let sd = SDBuidler::new(Task::ImageToImage, model_file.as_ref())
678            .map_err(|e| {
679                let err_msg = format!(
680                    "Failed to initialize the stable diffusion context. Reason: {}",
681                    e
682                );
683
684                #[cfg(feature = "logging")]
685                error!(target: "stdout", "{}", err_msg);
686
687                LlamaCoreError::InitContext(err_msg)
688            })?
689            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
690            .map_err(|e| {
691                let err_msg = format!(
692                    "Failed to initialize the stable diffusion context. Reason: {}",
693                    e
694                );
695
696                #[cfg(feature = "logging")]
697                error!(target: "stdout", "{}", err_msg);
698
699                LlamaCoreError::InitContext(err_msg)
700            })?
701            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
702            .map_err(|e| {
703                let err_msg = format!(
704                    "Failed to initialize the stable diffusion context. Reason: {}",
705                    e
706                );
707
708                #[cfg(feature = "logging")]
709                error!(target: "stdout", "{}", err_msg);
710
711                LlamaCoreError::InitContext(err_msg)
712            })?
713            .clip_on_cpu(clip_on_cpu)
714            .vae_on_cpu(vae_on_cpu)
715            .with_n_threads(n_threads)
716            .build();
717
718        #[cfg(feature = "logging")]
719        info!(target: "stdout", "sd: {:?}", &sd);
720
721        let ctx = sd.create_context().map_err(|e| {
722            let err_msg = format!("Fail to create the context. {}", e);
723
724            #[cfg(feature = "logging")]
725            error!(target: "stdout", "{}", &err_msg);
726
727            LlamaCoreError::InitContext(err_msg)
728        })?;
729
730        let ctx = match ctx {
731            Context::ImageToImage(ctx) => ctx,
732            _ => {
733                let err_msg = "Fail to get the context for the image-to-image task";
734
735                #[cfg(feature = "logging")]
736                error!(target: "stdout", "{}", err_msg);
737
738                return Err(LlamaCoreError::InitContext(err_msg.into()));
739            }
740        };
741
742        #[cfg(feature = "logging")]
743        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
744
745        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
746            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";
747
748            #[cfg(feature = "logging")]
749            error!(target: "stdout", "{}", err_msg);
750
751            LlamaCoreError::InitContext(err_msg.into())
752        })?;
753
754        #[cfg(feature = "logging")]
755        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
756    }
757
758    Ok(())
759}
760
761/// Initialize the stable-diffusion context with the given standalone diffusion model
762///
763/// # Arguments
764///
765/// * `model_file` - Path to the standalone diffusion model file.
766///
767/// * `vae` - Path to the VAE model file.
768///
769/// * `clip_l` - Path to the CLIP model file.
770///
771/// * `t5xxl` - Path to the T5-XXL model file.
772///
773/// * `lora_model_dir` - Path to the Lora model directory.
774///
775/// * `controlnet_path` - Path to the controlnet model file.
776///
777/// * `controlnet_on_cpu` - Whether to run the controlnet on CPU.
778///
779/// * `clip_on_cpu` - Whether to run the CLIP on CPU.
780///
781/// * `vae_on_cpu` - Whether to run the VAE on CPU.
782///
783/// * `n_threads` - Number of threads to use.
784///
785/// * `task` - The task type to perform.
786#[allow(clippy::too_many_arguments)]
787pub fn init_sd_context_with_standalone_model(
788    model_file: impl AsRef<str>,
789    vae: impl AsRef<str>,
790    clip_l: impl AsRef<str>,
791    t5xxl: impl AsRef<str>,
792    lora_model_dir: Option<&str>,
793    controlnet_path: Option<&str>,
794    controlnet_on_cpu: bool,
795    clip_on_cpu: bool,
796    vae_on_cpu: bool,
797    n_threads: i32,
798    task: StableDiffusionTask,
799) -> Result<(), LlamaCoreError> {
800    #[cfg(feature = "logging")]
801    info!(target: "stdout", "Initializing the stable diffusion context with the standalone diffusion model");
802
803    let control_net_on_cpu = match controlnet_path {
804        Some(path) if !path.is_empty() => controlnet_on_cpu,
805        _ => false,
806    };
807
808    // create the stable diffusion context for the text-to-image task
809    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
810        let sd = SDBuidler::new_with_standalone_model(Task::TextToImage, model_file.as_ref())
811            .map_err(|e| {
812                let err_msg = format!(
813                    "Failed to initialize the stable diffusion context. Reason: {}",
814                    e
815                );
816
817                #[cfg(feature = "logging")]
818                error!(target: "stdout", "{}", err_msg);
819
820                LlamaCoreError::InitContext(err_msg)
821            })?
822            .with_vae_path(vae.as_ref())
823            .map_err(|e| {
824                let err_msg = format!(
825                    "Failed to initialize the stable diffusion context. Reason: {}",
826                    e
827                );
828
829                #[cfg(feature = "logging")]
830                error!(target: "stdout", "{}", err_msg);
831
832                LlamaCoreError::InitContext(err_msg)
833            })?
834            .with_clip_l_path(clip_l.as_ref())
835            .map_err(|e| {
836                let err_msg = format!(
837                    "Failed to initialize the stable diffusion context. Reason: {}",
838                    e
839                );
840
841                #[cfg(feature = "logging")]
842                error!(target: "stdout", "{}", err_msg);
843
844                LlamaCoreError::InitContext(err_msg)
845            })?
846            .with_t5xxl_path(t5xxl.as_ref())
847            .map_err(|e| {
848                let err_msg = format!(
849                    "Failed to initialize the stable diffusion context. Reason: {}",
850                    e
851                );
852
853                #[cfg(feature = "logging")]
854                error!(target: "stdout", "{}", err_msg);
855
856                LlamaCoreError::InitContext(err_msg)
857            })?
858            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
859            .map_err(|e| {
860                let err_msg = format!(
861                    "Failed to initialize the stable diffusion context. Reason: {}",
862                    e
863                );
864
865                #[cfg(feature = "logging")]
866                error!(target: "stdout", "{}", err_msg);
867
868                LlamaCoreError::InitContext(err_msg)
869            })?
870            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
871            .map_err(|e| {
872                let err_msg = format!(
873                    "Failed to initialize the stable diffusion context. Reason: {}",
874                    e
875                );
876
877                #[cfg(feature = "logging")]
878                error!(target: "stdout", "{}", err_msg);
879
880                LlamaCoreError::InitContext(err_msg)
881            })?
882            .clip_on_cpu(clip_on_cpu)
883            .vae_on_cpu(vae_on_cpu)
884            .with_n_threads(n_threads)
885            .build();
886
887        #[cfg(feature = "logging")]
888        info!(target: "stdout", "sd: {:?}", &sd);
889
890        let ctx = sd.create_context().map_err(|e| {
891            let err_msg = format!("Fail to create the context. {}", e);
892
893            #[cfg(feature = "logging")]
894            error!(target: "stdout", "{}", &err_msg);
895
896            LlamaCoreError::InitContext(err_msg)
897        })?;
898
899        let ctx = match ctx {
900            Context::TextToImage(ctx) => ctx,
901            _ => {
902                let err_msg = "Fail to get the context for the text-to-image task";
903
904                #[cfg(feature = "logging")]
905                error!(target: "stdout", "{}", err_msg);
906
907                return Err(LlamaCoreError::InitContext(err_msg.into()));
908            }
909        };
910
911        #[cfg(feature = "logging")]
912        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
913
914        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
915            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";
916
917            #[cfg(feature = "logging")]
918            error!(target: "stdout", "{}", err_msg);
919
920            LlamaCoreError::InitContext(err_msg.into())
921        })?;
922
923        #[cfg(feature = "logging")]
924        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
925    }
926
927    // create the stable diffusion context for the image-to-image task
928    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
929        let sd = SDBuidler::new_with_standalone_model(Task::ImageToImage, model_file.as_ref())
930            .map_err(|e| {
931                let err_msg = format!(
932                    "Failed to initialize the stable diffusion context. Reason: {}",
933                    e
934                );
935
936                #[cfg(feature = "logging")]
937                error!(target: "stdout", "{}", err_msg);
938
939                LlamaCoreError::InitContext(err_msg)
940            })?
941            .with_vae_path(vae.as_ref())
942            .map_err(|e| {
943                let err_msg = format!(
944                    "Failed to initialize the stable diffusion context. Reason: {}",
945                    e
946                );
947
948                #[cfg(feature = "logging")]
949                error!(target: "stdout", "{}", err_msg);
950
951                LlamaCoreError::InitContext(err_msg)
952            })?
953            .with_clip_l_path(clip_l.as_ref())
954            .map_err(|e| {
955                let err_msg = format!(
956                    "Failed to initialize the stable diffusion context. Reason: {}",
957                    e
958                );
959
960                #[cfg(feature = "logging")]
961                error!(target: "stdout", "{}", err_msg);
962
963                LlamaCoreError::InitContext(err_msg)
964            })?
965            .with_t5xxl_path(t5xxl.as_ref())
966            .map_err(|e| {
967                let err_msg = format!(
968                    "Failed to initialize the stable diffusion context. Reason: {}",
969                    e
970                );
971
972                #[cfg(feature = "logging")]
973                error!(target: "stdout", "{}", err_msg);
974
975                LlamaCoreError::InitContext(err_msg)
976            })?
977            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
978            .map_err(|e| {
979                let err_msg = format!(
980                    "Failed to initialize the stable diffusion context. Reason: {}",
981                    e
982                );
983
984                #[cfg(feature = "logging")]
985                error!(target: "stdout", "{}", err_msg);
986
987                LlamaCoreError::InitContext(err_msg)
988            })?
989            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
990            .map_err(|e| {
991                let err_msg = format!(
992                    "Failed to initialize the stable diffusion context. Reason: {}",
993                    e
994                );
995
996                #[cfg(feature = "logging")]
997                error!(target: "stdout", "{}", err_msg);
998
999                LlamaCoreError::InitContext(err_msg)
1000            })?
1001            .clip_on_cpu(clip_on_cpu)
1002            .vae_on_cpu(vae_on_cpu)
1003            .with_n_threads(n_threads)
1004            .build();
1005
1006        #[cfg(feature = "logging")]
1007        info!(target: "stdout", "sd: {:?}", &sd);
1008
1009        let ctx = sd.create_context().map_err(|e| {
1010            let err_msg = format!("Fail to create the context. {}", e);
1011
1012            #[cfg(feature = "logging")]
1013            error!(target: "stdout", "{}", &err_msg);
1014
1015            LlamaCoreError::InitContext(err_msg)
1016        })?;
1017
1018        let ctx = match ctx {
1019            Context::ImageToImage(ctx) => ctx,
1020            _ => {
1021                let err_msg = "Fail to get the context for the image-to-image task";
1022
1023                #[cfg(feature = "logging")]
1024                error!(target: "stdout", "{}", err_msg);
1025
1026                return Err(LlamaCoreError::InitContext(err_msg.into()));
1027            }
1028        };
1029
1030        #[cfg(feature = "logging")]
1031        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
1032
1033        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
1034        let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";
1035
1036        #[cfg(feature = "logging")]
1037        error!(target: "stdout", "{}", err_msg);
1038
1039        LlamaCoreError::InitContext(err_msg.into())
1040    })?;
1041
1042        #[cfg(feature = "logging")]
1043        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
1044    }
1045
1046    Ok(())
1047}
1048
/// Selects which stable diffusion context(s) should be created.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum StableDiffusionTask {
    /// Create only the `text_to_image` context.
    TextToImage,
    /// Create only the `image_to_image` context.
    ImageToImage,
    /// Create both the `text_to_image` and `image_to_image` contexts.
    Full,
}
1059
/// Initialize the whisper context
#[cfg(feature = "whisper")]
pub fn init_whisper_context(whisper_metadata: &WhisperMetadata) -> Result<(), LlamaCoreError> {
    // Build the audio graph from the whisper model file, forcing CPU execution.
    let graph = GraphBuilder::new(EngineType::Whisper)?
        .with_config(whisper_metadata.clone())?
        .use_cpu()
        .build_from_files([&whisper_metadata.model_path])?;

    if let Some(existing) = AUDIO_GRAPH.get() {
        // The global slot is already populated: swap in the freshly built graph.
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Re-initialize the audio context");

        let mut guard = existing.lock().map_err(|e| {
            let err_msg = format!("Failed to lock the graph. Reason: {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;
        *guard = graph;
    } else {
        // First-time initialization: publish the graph into the global slot.
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Initialize the audio context");

        AUDIO_GRAPH.set(Mutex::new(graph)).map_err(|_| {
            let err_msg = "Failed to initialize the audio context. Reason: The `AUDIO_GRAPH` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "The audio context has been initialized");

    Ok(())
}
1106
1107/// Initialize the piper context
1108///
1109/// # Arguments
1110///
1111/// * `voice_model` - Path to the voice model file.
1112///
1113/// * `voice_config` - Path to the voice config file.
1114///
1115/// * `espeak_ng_data` - Path to the espeak-ng data directory.
1116///
1117pub fn init_piper_context(
1118    piper_metadata: &PiperMetadata,
1119    voice_model: impl AsRef<Path>,
1120    voice_config: impl AsRef<Path>,
1121    espeak_ng_data: impl AsRef<Path>,
1122) -> Result<(), LlamaCoreError> {
1123    #[cfg(feature = "logging")]
1124    info!(target: "stdout", "Initializing the piper context");
1125
1126    let config = serde_json::json!({
1127        "model": voice_model.as_ref().to_owned(),
1128        "config": voice_config.as_ref().to_owned(),
1129        "espeak_data": espeak_ng_data.as_ref().to_owned(),
1130    });
1131
1132    // create and initialize the audio context
1133    let graph = GraphBuilder::new(EngineType::Piper)?
1134        .with_config(piper_metadata.clone())?
1135        .use_cpu()
1136        .build_from_buffer([config.to_string()])?;
1137
1138    PIPER_GRAPH.set(Mutex::new(graph)).map_err(|_| {
1139            let err_msg = "Failed to initialize the piper context. Reason: The `PIPER_GRAPH` has already been initialized";
1140
1141            #[cfg(feature = "logging")]
1142            error!(target: "stdout", "{}", err_msg);
1143
1144            LlamaCoreError::InitContext(err_msg.into())
1145        })?;
1146
1147    #[cfg(feature = "logging")]
1148    info!(target: "stdout", "The piper context has been initialized");
1149
1150    Ok(())
1151}