// llama_core/lib.rs

1//! Llama Core, abbreviated as `llama-core`, defines a set of APIs. Developers can utilize these APIs to build applications based on large models, such as chatbots, RAG, and more.
2
3#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
4
5#[cfg(feature = "logging")]
6#[macro_use]
7extern crate log;
8
9pub mod audio;
10pub mod chat;
11pub mod completions;
12pub mod embeddings;
13pub mod error;
14pub mod files;
15pub mod graph;
16pub mod images;
17pub mod metadata;
18pub mod models;
19pub mod tts;
20pub mod utils;
21
22pub use error::LlamaCoreError;
23pub use graph::{EngineType, Graph, GraphBuilder};
24#[cfg(feature = "whisper")]
25use metadata::whisper::WhisperMetadata;
26pub use metadata::{
27    ggml::{GgmlMetadata, GgmlTtsMetadata},
28    piper::PiperMetadata,
29    BaseMetadata,
30};
31use once_cell::sync::OnceCell;
32use std::{
33    collections::HashMap,
34    path::Path,
35    sync::{Mutex, RwLock},
36};
37use utils::{get_output_buffer, RunningMode};
38use wasmedge_stable_diffusion::*;
39
// Chat model graphs; key: model_name, value: Graph
pub(crate) static CHAT_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
// Embedding model graphs; key: model_name, value: Graph
pub(crate) static EMBEDDING_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
// TTS model graphs; key: model_name, value: Graph
pub(crate) static TTS_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlTtsMetadata>>>> =
    OnceCell::new();
// cache bytes for decoding utf8 (presumably partial multi-byte sequences carried
// across streaming calls — see `utils`)
pub(crate) static CACHED_UTF8_ENCODINGS: OnceCell<Mutex<Vec<u8>>> = OnceCell::new();
// running mode; a bitflag set — the `init_ggml_*_context` functions OR their mode in
pub(crate) static RUNNING_MODE: OnceCell<RwLock<RunningMode>> = OnceCell::new();
// stable diffusion context for the text-to-image task
pub(crate) static SD_TEXT_TO_IMAGE: OnceCell<Mutex<TextToImage>> = OnceCell::new();
// stable diffusion context for the image-to-image task
pub(crate) static SD_IMAGE_TO_IMAGE: OnceCell<Mutex<ImageToImage>> = OnceCell::new();
// context for the audio (whisper) task
#[cfg(feature = "whisper")]
pub(crate) static AUDIO_GRAPH: OnceCell<Mutex<Graph<WhisperMetadata>>> = OnceCell::new();
// context for the piper task
pub(crate) static PIPER_GRAPH: OnceCell<Mutex<Graph<PiperMetadata>>> = OnceCell::new();

// Maximum output buffer size in bytes: 2^14 * 15 + 128 = 245_888.
pub(crate) const MAX_BUFFER_SIZE: usize = 2usize.pow(14) * 15 + 128;
// Output index of the model's primary output tensor.
pub(crate) const OUTPUT_TENSOR: usize = 0;
// Output index used to fetch the plugin's version metadata
// (consumed by `get_plugin_info_by_graph`).
const PLUGIN_VERSION: usize = 1;

/// The directory for storing the archives in wasm virtual file system.
pub const ARCHIVES_DIR: &str = "archives";
69
70/// Initialize the ggml context
71pub fn init_ggml_chat_context(metadata_for_chats: &[GgmlMetadata]) -> Result<(), LlamaCoreError> {
72    #[cfg(feature = "logging")]
73    info!(target: "stdout", "Initializing the core context");
74
75    if metadata_for_chats.is_empty() {
76        let err_msg = "The metadata for chat models is empty";
77
78        #[cfg(feature = "logging")]
79        error!(target: "stdout", "{err_msg}");
80
81        return Err(LlamaCoreError::InitContext(err_msg.into()));
82    }
83
84    let mut chat_graphs = HashMap::new();
85    for metadata in metadata_for_chats {
86        let graph = Graph::new(metadata.clone())?;
87
88        chat_graphs.insert(graph.name().to_string(), graph);
89    }
90    CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
91            let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";
92
93            #[cfg(feature = "logging")]
94            error!(target: "stdout", "{err_msg}");
95
96            LlamaCoreError::InitContext(err_msg.into())
97        })?;
98
99    // set running mode
100    let running_mode = RunningMode::CHAT;
101    match RUNNING_MODE.get() {
102        Some(mode) => {
103            let mut mode = mode.write().unwrap();
104            *mode |= running_mode;
105        }
106        None => {
107            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
108                let err_msg = "Failed to initialize the chat context. Reason: The `RUNNING_MODE` has already been initialized";
109
110                #[cfg(feature = "logging")]
111                error!(target: "stdout", "{err_msg}");
112
113                LlamaCoreError::InitContext(err_msg.into())
114            })?;
115        }
116    }
117
118    Ok(())
119}
120
121/// Initialize the ggml context
122pub fn init_ggml_embeddings_context(
123    metadata_for_embeddings: &[GgmlMetadata],
124) -> Result<(), LlamaCoreError> {
125    #[cfg(feature = "logging")]
126    info!(target: "stdout", "Initializing the embeddings context");
127
128    if metadata_for_embeddings.is_empty() {
129        let err_msg = "The metadata for chat models is empty";
130
131        #[cfg(feature = "logging")]
132        error!(target: "stdout", "{err_msg}");
133
134        return Err(LlamaCoreError::InitContext(err_msg.into()));
135    }
136
137    let mut embedding_graphs = HashMap::new();
138    for metadata in metadata_for_embeddings {
139        let graph = Graph::new(metadata.clone())?;
140
141        embedding_graphs.insert(graph.name().to_string(), graph);
142    }
143    EMBEDDING_GRAPHS
144            .set(Mutex::new(embedding_graphs))
145            .map_err(|_| {
146                let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";
147
148                #[cfg(feature = "logging")]
149                error!(target: "stdout", "{err_msg}");
150
151                LlamaCoreError::InitContext(err_msg.into())
152            })?;
153
154    // set running mode
155    let running_mode = RunningMode::EMBEDDINGS;
156    match RUNNING_MODE.get() {
157        Some(mode) => {
158            let mut mode = mode.write().unwrap();
159            *mode |= running_mode;
160        }
161        None => {
162            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
163                let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
164
165                #[cfg(feature = "logging")]
166                error!(target: "stdout", "{err_msg}");
167
168                LlamaCoreError::InitContext(err_msg.into())
169            })?;
170        }
171    }
172
173    Ok(())
174}
175
176/// Initialize the ggml context for TTS scenarios.
177pub fn init_ggml_tts_context(metadata_for_tts: &[GgmlTtsMetadata]) -> Result<(), LlamaCoreError> {
178    #[cfg(feature = "logging")]
179    info!(target: "stdout", "Initializing the TTS context");
180
181    if metadata_for_tts.is_empty() {
182        let err_msg = "The metadata for tts models is empty";
183
184        #[cfg(feature = "logging")]
185        error!(target: "stdout", "{err_msg}");
186
187        return Err(LlamaCoreError::InitContext(err_msg.into()));
188    }
189
190    let mut tts_graphs = HashMap::new();
191    for metadata in metadata_for_tts {
192        let graph = Graph::new(metadata.clone())?;
193
194        tts_graphs.insert(graph.name().to_string(), graph);
195    }
196    TTS_GRAPHS.set(Mutex::new(tts_graphs)).map_err(|_| {
197        let err_msg = "Failed to initialize the core context. Reason: The `TTS_GRAPHS` has already been initialized";
198
199        #[cfg(feature = "logging")]
200        error!(target: "stdout", "{err_msg}");
201
202        LlamaCoreError::InitContext(err_msg.into())
203    })?;
204
205    // set running mode
206    let running_mode = RunningMode::TTS;
207    match RUNNING_MODE.get() {
208        Some(mode) => {
209            let mut mode = mode.write().unwrap();
210            *mode |= running_mode;
211        }
212        None => {
213            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
214                let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
215
216                #[cfg(feature = "logging")]
217                error!(target: "stdout", "{err_msg}");
218
219                LlamaCoreError::InitContext(err_msg.into())
220            })?;
221        }
222    }
223
224    Ok(())
225}
226
227/// Get the plugin info
228///
229/// Note that it is required to call `init_core_context` before calling this function.
230pub fn get_plugin_info() -> Result<PluginInfo, LlamaCoreError> {
231    #[cfg(feature = "logging")]
232    debug!(target: "stdout", "Getting the plugin info");
233
234    let running_mode = running_mode()?;
235
236    if running_mode.contains(RunningMode::CHAT) || running_mode.contains(RunningMode::RAG) {
237        let chat_graphs = match CHAT_GRAPHS.get() {
238            Some(chat_graphs) => chat_graphs,
239            None => {
240                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
241
242                #[cfg(feature = "logging")]
243                error!(target: "stdout", "{err_msg}");
244
245                return Err(LlamaCoreError::Operation(err_msg.into()));
246            }
247        };
248
249        let chat_graphs = chat_graphs.lock().map_err(|e| {
250            let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
251
252            #[cfg(feature = "logging")]
253            error!(target: "stdout", "{}", &err_msg);
254
255            LlamaCoreError::Operation(err_msg)
256        })?;
257
258        let graph = match chat_graphs.values().next() {
259            Some(graph) => graph,
260            None => {
261                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
262
263                #[cfg(feature = "logging")]
264                error!(target: "stdout", "{err_msg}");
265
266                return Err(LlamaCoreError::Operation(err_msg.into()));
267            }
268        };
269
270        get_plugin_info_by_graph(graph)
271    } else if running_mode.contains(RunningMode::EMBEDDINGS) {
272        let embedding_graphs = match EMBEDDING_GRAPHS.get() {
273            Some(embedding_graphs) => embedding_graphs,
274            None => {
275                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
276
277                #[cfg(feature = "logging")]
278                error!(target: "stdout", "{err_msg}");
279
280                return Err(LlamaCoreError::Operation(err_msg.into()));
281            }
282        };
283
284        let embedding_graphs = embedding_graphs.lock().map_err(|e| {
285            let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {e}");
286
287            #[cfg(feature = "logging")]
288            error!(target: "stdout", "{}", &err_msg);
289
290            LlamaCoreError::Operation(err_msg)
291        })?;
292
293        let graph = match embedding_graphs.values().next() {
294            Some(graph) => graph,
295            None => {
296                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
297
298                #[cfg(feature = "logging")]
299                error!(target: "stdout", "{err_msg}");
300
301                return Err(LlamaCoreError::Operation(err_msg.into()));
302            }
303        };
304
305        get_plugin_info_by_graph(graph)
306    } else if running_mode.contains(RunningMode::TTS) {
307        let tts_graphs = match TTS_GRAPHS.get() {
308            Some(tts_graphs) => tts_graphs,
309            None => {
310                let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
311
312                #[cfg(feature = "logging")]
313                error!(target: "stdout", "{err_msg}");
314
315                return Err(LlamaCoreError::Operation(err_msg.into()));
316            }
317        };
318
319        let tts_graphs = tts_graphs.lock().map_err(|e| {
320            let err_msg = format!("Fail to acquire the lock of `TTS_GRAPHS`. {e}");
321
322            #[cfg(feature = "logging")]
323            error!(target: "stdout", "{}", &err_msg);
324
325            LlamaCoreError::Operation(err_msg)
326        })?;
327
328        let graph = match tts_graphs.values().next() {
329            Some(graph) => graph,
330            None => {
331                let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
332
333                #[cfg(feature = "logging")]
334                error!(target: "stdout", "{err_msg}");
335
336                return Err(LlamaCoreError::Operation(err_msg.into()));
337            }
338        };
339
340        get_plugin_info_by_graph(graph)
341    } else {
342        let err_msg = "RUNNING_MODE is not set";
343
344        #[cfg(feature = "logging")]
345        error!(target: "stdout", "{err_msg}");
346
347        Err(LlamaCoreError::Operation(err_msg.into()))
348    }
349}
350
351fn get_plugin_info_by_graph<M: BaseMetadata + serde::Serialize + Clone + Default>(
352    graph: &Graph<M>,
353) -> Result<PluginInfo, LlamaCoreError> {
354    #[cfg(feature = "logging")]
355    debug!(target: "stdout", "Getting the plugin info by the graph named {}", graph.name());
356
357    // get the plugin metadata
358    let output_buffer = get_output_buffer(graph, PLUGIN_VERSION)?;
359    let metadata: serde_json::Value = serde_json::from_slice(&output_buffer[..]).map_err(|e| {
360        let err_msg = format!("Fail to deserialize the plugin metadata. {e}");
361
362        #[cfg(feature = "logging")]
363        error!(target: "stdout", "{}", &err_msg);
364
365        LlamaCoreError::Operation(err_msg)
366    })?;
367
368    // get build number of the plugin
369    let plugin_build_number = match metadata.get("llama_build_number") {
370        Some(value) => match value.as_u64() {
371            Some(number) => number,
372            None => {
373                let err_msg = "Failed to convert the build number of the plugin to u64";
374
375                #[cfg(feature = "logging")]
376                error!(target: "stdout", "{err_msg}");
377
378                return Err(LlamaCoreError::Operation(err_msg.into()));
379            }
380        },
381        None => {
382            let err_msg = "Metadata does not have the field `llama_build_number`.";
383
384            #[cfg(feature = "logging")]
385            error!(target: "stdout", "{err_msg}");
386
387            return Err(LlamaCoreError::Operation(err_msg.into()));
388        }
389    };
390
391    // get commit id of the plugin
392    let plugin_commit = match metadata.get("llama_commit") {
393        Some(value) => match value.as_str() {
394            Some(commit) => commit,
395            None => {
396                let err_msg = "Failed to convert the commit id of the plugin to string";
397
398                #[cfg(feature = "logging")]
399                error!(target: "stdout", "{err_msg}");
400
401                return Err(LlamaCoreError::Operation(err_msg.into()));
402            }
403        },
404        None => {
405            let err_msg = "Metadata does not have the field `llama_commit`.";
406
407            #[cfg(feature = "logging")]
408            error!(target: "stdout", "{err_msg}");
409
410            return Err(LlamaCoreError::Operation(err_msg.into()));
411        }
412    };
413
414    #[cfg(feature = "logging")]
415    debug!(target: "stdout", "Plugin info: b{plugin_build_number}(commit {plugin_commit})");
416
417    Ok(PluginInfo {
418        build_number: plugin_build_number,
419        commit_id: plugin_commit.to_string(),
420    })
421}
422
/// Version info of the `wasi-nn_ggml` plugin, including the build number and the commit id.
#[derive(Debug, Clone)]
pub struct PluginInfo {
    /// Build number reported by the plugin.
    pub build_number: u64,
    /// Git commit id the plugin was built from.
    pub commit_id: String,
}
impl std::fmt::Display for PluginInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "wasinn-ggml plugin: b{build}(commit {commit})",
            build = self.build_number,
            commit = self.commit_id
        )
    }
}
438
439/// Return the current running mode.
440pub fn running_mode() -> Result<RunningMode, LlamaCoreError> {
441    #[cfg(feature = "logging")]
442    debug!(target: "stdout", "Get the running mode.");
443
444    match RUNNING_MODE.get() {
445        Some(mode) => match mode.read() {
446            Ok(mode) => Ok(*mode),
447            Err(e) => {
448                let err_msg = format!("Fail to get the underlying value of `RUNNING_MODE`. {e}");
449
450                #[cfg(feature = "logging")]
451                error!(target: "stdout", "{err_msg}");
452
453                Err(LlamaCoreError::Operation(err_msg))
454            }
455        },
456        None => {
457            let err_msg = "Fail to get the underlying value of `RUNNING_MODE`.";
458
459            #[cfg(feature = "logging")]
460            error!(target: "stdout", "{err_msg}");
461
462            Err(LlamaCoreError::Operation(err_msg.into()))
463        }
464    }
465}
466
467/// Initialize the stable-diffusion context with the given full diffusion model
468///
469/// # Arguments
470///
471/// * `model_file` - Path to the stable diffusion model file.
472///
473/// * `lora_model_dir` - Path to the Lora model directory.
474///
475/// * `controlnet_path` - Path to the controlnet model file.
476///
477/// * `controlnet_on_cpu` - Whether to run the controlnet on CPU.
478///
479/// * `clip_on_cpu` - Whether to run the CLIP on CPU.
480///
481/// * `vae_on_cpu` - Whether to run the VAE on CPU.
482///
483/// * `n_threads` - Number of threads to use.
484///
485/// * `task` - The task type to perform.
486#[allow(clippy::too_many_arguments)]
487pub fn init_sd_context_with_full_model(
488    model_file: impl AsRef<str>,
489    lora_model_dir: Option<&str>,
490    controlnet_path: Option<&str>,
491    controlnet_on_cpu: bool,
492    clip_on_cpu: bool,
493    vae_on_cpu: bool,
494    n_threads: i32,
495    task: StableDiffusionTask,
496) -> Result<(), LlamaCoreError> {
497    #[cfg(feature = "logging")]
498    info!(target: "stdout", "Initializing the stable diffusion context with the full model");
499
500    let control_net_on_cpu = match controlnet_path {
501        Some(path) if !path.is_empty() => controlnet_on_cpu,
502        _ => false,
503    };
504
505    // create the stable diffusion context for the text-to-image task
506    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
507        let sd = SDBuidler::new(Task::TextToImage, model_file.as_ref())
508            .map_err(|e| {
509                let err_msg =
510                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
511
512                #[cfg(feature = "logging")]
513                error!(target: "stdout", "{err_msg}");
514
515                LlamaCoreError::InitContext(err_msg)
516            })?
517            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
518            .map_err(|e| {
519                let err_msg =
520                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
521
522                #[cfg(feature = "logging")]
523                error!(target: "stdout", "{err_msg}");
524
525                LlamaCoreError::InitContext(err_msg)
526            })?
527            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
528            .map_err(|e| {
529                let err_msg =
530                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
531
532                #[cfg(feature = "logging")]
533                error!(target: "stdout", "{err_msg}");
534
535                LlamaCoreError::InitContext(err_msg)
536            })?
537            .clip_on_cpu(clip_on_cpu)
538            .vae_on_cpu(vae_on_cpu)
539            .with_n_threads(n_threads)
540            .build();
541
542        #[cfg(feature = "logging")]
543        info!(target: "stdout", "sd: {:?}", &sd);
544
545        let ctx = sd.create_context().map_err(|e| {
546            let err_msg = format!("Fail to create the context. {e}");
547
548            #[cfg(feature = "logging")]
549            error!(target: "stdout", "{}", &err_msg);
550
551            LlamaCoreError::InitContext(err_msg)
552        })?;
553
554        let ctx = match ctx {
555            Context::TextToImage(ctx) => ctx,
556            _ => {
557                let err_msg = "Fail to get the context for the text-to-image task";
558
559                #[cfg(feature = "logging")]
560                error!(target: "stdout", "{err_msg}");
561
562                return Err(LlamaCoreError::InitContext(err_msg.into()));
563            }
564        };
565
566        #[cfg(feature = "logging")]
567        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
568
569        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
570        let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";
571
572        #[cfg(feature = "logging")]
573        error!(target: "stdout", "{err_msg}");
574
575        LlamaCoreError::InitContext(err_msg.into())
576    })?;
577
578        #[cfg(feature = "logging")]
579        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
580    }
581
582    // create the stable diffusion context for the image-to-image task
583    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
584        let sd = SDBuidler::new(Task::ImageToImage, model_file.as_ref())
585            .map_err(|e| {
586                let err_msg =
587                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
588
589                #[cfg(feature = "logging")]
590                error!(target: "stdout", "{err_msg}");
591
592                LlamaCoreError::InitContext(err_msg)
593            })?
594            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
595            .map_err(|e| {
596                let err_msg =
597                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
598
599                #[cfg(feature = "logging")]
600                error!(target: "stdout", "{err_msg}");
601
602                LlamaCoreError::InitContext(err_msg)
603            })?
604            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
605            .map_err(|e| {
606                let err_msg =
607                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
608
609                #[cfg(feature = "logging")]
610                error!(target: "stdout", "{err_msg}");
611
612                LlamaCoreError::InitContext(err_msg)
613            })?
614            .clip_on_cpu(clip_on_cpu)
615            .vae_on_cpu(vae_on_cpu)
616            .with_n_threads(n_threads)
617            .build();
618
619        #[cfg(feature = "logging")]
620        info!(target: "stdout", "sd: {:?}", &sd);
621
622        let ctx = sd.create_context().map_err(|e| {
623            let err_msg = format!("Fail to create the context. {e}");
624
625            #[cfg(feature = "logging")]
626            error!(target: "stdout", "{}", &err_msg);
627
628            LlamaCoreError::InitContext(err_msg)
629        })?;
630
631        let ctx = match ctx {
632            Context::ImageToImage(ctx) => ctx,
633            _ => {
634                let err_msg = "Fail to get the context for the image-to-image task";
635
636                #[cfg(feature = "logging")]
637                error!(target: "stdout", "{err_msg}");
638
639                return Err(LlamaCoreError::InitContext(err_msg.into()));
640            }
641        };
642
643        #[cfg(feature = "logging")]
644        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
645
646        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
647            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";
648
649            #[cfg(feature = "logging")]
650            error!(target: "stdout", "{err_msg}");
651
652            LlamaCoreError::InitContext(err_msg.into())
653        })?;
654
655        #[cfg(feature = "logging")]
656        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
657    }
658
659    Ok(())
660}
661
662/// Initialize the stable-diffusion context with the given standalone diffusion model
663///
664/// # Arguments
665///
666/// * `model_file` - Path to the standalone diffusion model file.
667///
668/// * `vae` - Path to the VAE model file.
669///
670/// * `clip_l` - Path to the CLIP model file.
671///
672/// * `t5xxl` - Path to the T5-XXL model file.
673///
674/// * `lora_model_dir` - Path to the Lora model directory.
675///
676/// * `controlnet_path` - Path to the controlnet model file.
677///
678/// * `controlnet_on_cpu` - Whether to run the controlnet on CPU.
679///
680/// * `clip_on_cpu` - Whether to run the CLIP on CPU.
681///
682/// * `vae_on_cpu` - Whether to run the VAE on CPU.
683///
684/// * `n_threads` - Number of threads to use.
685///
686/// * `task` - The task type to perform.
687#[allow(clippy::too_many_arguments)]
688pub fn init_sd_context_with_standalone_model(
689    model_file: impl AsRef<str>,
690    vae: impl AsRef<str>,
691    clip_l: impl AsRef<str>,
692    t5xxl: impl AsRef<str>,
693    lora_model_dir: Option<&str>,
694    controlnet_path: Option<&str>,
695    controlnet_on_cpu: bool,
696    clip_on_cpu: bool,
697    vae_on_cpu: bool,
698    n_threads: i32,
699    task: StableDiffusionTask,
700) -> Result<(), LlamaCoreError> {
701    #[cfg(feature = "logging")]
702    info!(target: "stdout", "Initializing the stable diffusion context with the standalone diffusion model");
703
704    let control_net_on_cpu = match controlnet_path {
705        Some(path) if !path.is_empty() => controlnet_on_cpu,
706        _ => false,
707    };
708
709    // create the stable diffusion context for the text-to-image task
710    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
711        let sd = SDBuidler::new_with_standalone_model(Task::TextToImage, model_file.as_ref())
712            .map_err(|e| {
713                let err_msg =
714                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
715
716                #[cfg(feature = "logging")]
717                error!(target: "stdout", "{err_msg}");
718
719                LlamaCoreError::InitContext(err_msg)
720            })?
721            .with_vae_path(vae.as_ref())
722            .map_err(|e| {
723                let err_msg =
724                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
725
726                #[cfg(feature = "logging")]
727                error!(target: "stdout", "{err_msg}");
728
729                LlamaCoreError::InitContext(err_msg)
730            })?
731            .with_clip_l_path(clip_l.as_ref())
732            .map_err(|e| {
733                let err_msg =
734                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
735
736                #[cfg(feature = "logging")]
737                error!(target: "stdout", "{err_msg}");
738
739                LlamaCoreError::InitContext(err_msg)
740            })?
741            .with_t5xxl_path(t5xxl.as_ref())
742            .map_err(|e| {
743                let err_msg =
744                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
745
746                #[cfg(feature = "logging")]
747                error!(target: "stdout", "{err_msg}");
748
749                LlamaCoreError::InitContext(err_msg)
750            })?
751            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
752            .map_err(|e| {
753                let err_msg =
754                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
755
756                #[cfg(feature = "logging")]
757                error!(target: "stdout", "{err_msg}");
758
759                LlamaCoreError::InitContext(err_msg)
760            })?
761            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
762            .map_err(|e| {
763                let err_msg =
764                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
765
766                #[cfg(feature = "logging")]
767                error!(target: "stdout", "{err_msg}");
768
769                LlamaCoreError::InitContext(err_msg)
770            })?
771            .clip_on_cpu(clip_on_cpu)
772            .vae_on_cpu(vae_on_cpu)
773            .with_n_threads(n_threads)
774            .build();
775
776        #[cfg(feature = "logging")]
777        info!(target: "stdout", "sd: {:?}", &sd);
778
779        let ctx = sd.create_context().map_err(|e| {
780            let err_msg = format!("Fail to create the context. {e}");
781
782            #[cfg(feature = "logging")]
783            error!(target: "stdout", "{}", &err_msg);
784
785            LlamaCoreError::InitContext(err_msg)
786        })?;
787
788        let ctx = match ctx {
789            Context::TextToImage(ctx) => ctx,
790            _ => {
791                let err_msg = "Fail to get the context for the text-to-image task";
792
793                #[cfg(feature = "logging")]
794                error!(target: "stdout", "{err_msg}");
795
796                return Err(LlamaCoreError::InitContext(err_msg.into()));
797            }
798        };
799
800        #[cfg(feature = "logging")]
801        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
802
803        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
804            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";
805
806            #[cfg(feature = "logging")]
807            error!(target: "stdout", "{err_msg}");
808
809            LlamaCoreError::InitContext(err_msg.into())
810        })?;
811
812        #[cfg(feature = "logging")]
813        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
814    }
815
816    // create the stable diffusion context for the image-to-image task
817    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
818        let sd = SDBuidler::new_with_standalone_model(Task::ImageToImage, model_file.as_ref())
819            .map_err(|e| {
820                let err_msg =
821                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
822
823                #[cfg(feature = "logging")]
824                error!(target: "stdout", "{err_msg}");
825
826                LlamaCoreError::InitContext(err_msg)
827            })?
828            .with_vae_path(vae.as_ref())
829            .map_err(|e| {
830                let err_msg =
831                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
832
833                #[cfg(feature = "logging")]
834                error!(target: "stdout", "{err_msg}");
835
836                LlamaCoreError::InitContext(err_msg)
837            })?
838            .with_clip_l_path(clip_l.as_ref())
839            .map_err(|e| {
840                let err_msg =
841                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
842
843                #[cfg(feature = "logging")]
844                error!(target: "stdout", "{err_msg}");
845
846                LlamaCoreError::InitContext(err_msg)
847            })?
848            .with_t5xxl_path(t5xxl.as_ref())
849            .map_err(|e| {
850                let err_msg =
851                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
852
853                #[cfg(feature = "logging")]
854                error!(target: "stdout", "{err_msg}");
855
856                LlamaCoreError::InitContext(err_msg)
857            })?
858            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
859            .map_err(|e| {
860                let err_msg =
861                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
862
863                #[cfg(feature = "logging")]
864                error!(target: "stdout", "{err_msg}");
865
866                LlamaCoreError::InitContext(err_msg)
867            })?
868            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
869            .map_err(|e| {
870                let err_msg =
871                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
872
873                #[cfg(feature = "logging")]
874                error!(target: "stdout", "{err_msg}");
875
876                LlamaCoreError::InitContext(err_msg)
877            })?
878            .clip_on_cpu(clip_on_cpu)
879            .vae_on_cpu(vae_on_cpu)
880            .with_n_threads(n_threads)
881            .build();
882
883        #[cfg(feature = "logging")]
884        info!(target: "stdout", "sd: {:?}", &sd);
885
886        let ctx = sd.create_context().map_err(|e| {
887            let err_msg = format!("Fail to create the context. {e}");
888
889            #[cfg(feature = "logging")]
890            error!(target: "stdout", "{}", &err_msg);
891
892            LlamaCoreError::InitContext(err_msg)
893        })?;
894
895        let ctx = match ctx {
896            Context::ImageToImage(ctx) => ctx,
897            _ => {
898                let err_msg = "Fail to get the context for the image-to-image task";
899
900                #[cfg(feature = "logging")]
901                error!(target: "stdout", "{err_msg}");
902
903                return Err(LlamaCoreError::InitContext(err_msg.into()));
904            }
905        };
906
907        #[cfg(feature = "logging")]
908        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
909
910        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
911        let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";
912
913        #[cfg(feature = "logging")]
914        error!(target: "stdout", "{err_msg}");
915
916        LlamaCoreError::InitContext(err_msg.into())
917    })?;
918
919        #[cfg(feature = "logging")]
920        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
921    }
922
923    Ok(())
924}
925
/// Which stable diffusion context(s) to create during initialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StableDiffusionTask {
    /// Only the `text_to_image` context
    TextToImage,
    /// Only the `image_to_image` context
    ImageToImage,
    /// Both the `text_to_image` and `image_to_image` contexts
    Full,
}
936
/// Initialize the whisper context
#[cfg(feature = "whisper")]
pub fn init_whisper_context(whisper_metadata: &WhisperMetadata) -> Result<(), LlamaCoreError> {
    // Build the audio inference graph from the whisper model file, forcing CPU execution.
    let new_graph = GraphBuilder::new(EngineType::Whisper)?
        .with_config(whisper_metadata.clone())?
        .use_cpu()
        .build_from_files([&whisper_metadata.model_path])?;

    if let Some(existing) = AUDIO_GRAPH.get() {
        // A graph is already installed: replace it in place under the lock.
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Re-initialize the audio context");

        match existing.lock() {
            Ok(mut guard) => *guard = new_graph,
            Err(e) => {
                // A poisoned mutex means a previous holder panicked; surface it as an init error.
                let err_msg = format!("Failed to lock the graph. Reason: {e}");

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");

                return Err(LlamaCoreError::InitContext(err_msg));
            }
        }
    } else {
        // First-time initialization: install the graph into the global cell.
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Initialize the audio context");

        AUDIO_GRAPH.set(Mutex::new(new_graph)).map_err(|_| {
            let err_msg = "Failed to initialize the audio context. Reason: The `AUDIO_GRAPH` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg.into())
        })?;
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "The audio context has been initialized");

    Ok(())
}
983
984/// Initialize the piper context
985///
986/// # Arguments
987///
988/// * `voice_model` - Path to the voice model file.
989///
990/// * `voice_config` - Path to the voice config file.
991///
992/// * `espeak_ng_data` - Path to the espeak-ng data directory.
993///
994pub fn init_piper_context(
995    piper_metadata: &PiperMetadata,
996    voice_model: impl AsRef<Path>,
997    voice_config: impl AsRef<Path>,
998    espeak_ng_data: impl AsRef<Path>,
999) -> Result<(), LlamaCoreError> {
1000    #[cfg(feature = "logging")]
1001    info!(target: "stdout", "Initializing the piper context");
1002
1003    let config = serde_json::json!({
1004        "model": voice_model.as_ref().to_owned(),
1005        "config": voice_config.as_ref().to_owned(),
1006        "espeak_data": espeak_ng_data.as_ref().to_owned(),
1007    });
1008
1009    // create and initialize the audio context
1010    let graph = GraphBuilder::new(EngineType::Piper)?
1011        .with_config(piper_metadata.clone())?
1012        .use_cpu()
1013        .build_from_buffer([config.to_string()])?;
1014
1015    PIPER_GRAPH.set(Mutex::new(graph)).map_err(|_| {
1016            let err_msg = "Failed to initialize the piper context. Reason: The `PIPER_GRAPH` has already been initialized";
1017
1018            #[cfg(feature = "logging")]
1019            error!(target: "stdout", "{err_msg}");
1020
1021            LlamaCoreError::InitContext(err_msg.into())
1022        })?;
1023
1024    #[cfg(feature = "logging")]
1025    info!(target: "stdout", "The piper context has been initialized");
1026
1027    Ok(())
1028}