// llama_core/lib.rs
1//! Llama Core, abbreviated as `llama-core`, defines a set of APIs. Developers can utilize these APIs to build applications based on large models, such as chatbots, RAG, and more.
2
3#![cfg_attr(docsrs, feature(doc_cfg))]
4
5#[cfg(feature = "logging")]
6#[macro_use]
7extern crate log;
8
9pub mod audio;
10pub mod chat;
11pub mod completions;
12pub mod embeddings;
13pub mod error;
14pub mod files;
15pub mod graph;
16pub mod images;
17pub mod metadata;
18pub mod models;
19pub mod tts;
20pub mod utils;
21
22pub use error::LlamaCoreError;
23pub use graph::{EngineType, Graph, GraphBuilder};
24#[cfg(feature = "whisper")]
25use metadata::whisper::WhisperMetadata;
26pub use metadata::{
27    ggml::{GgmlMetadata, GgmlTtsMetadata},
28    piper::PiperMetadata,
29    BaseMetadata,
30};
31use once_cell::sync::OnceCell;
32use std::{
33    collections::HashMap,
34    path::Path,
35    sync::{Mutex, RwLock},
36};
37use utils::{get_output_buffer, RunningMode};
38use wasmedge_stable_diffusion::*;
39
// Chat model contexts. key: model_name, value: Graph
pub(crate) static CHAT_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
// Embeddings model contexts. key: model_name, value: Graph
pub(crate) static EMBEDDING_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
    OnceCell::new();
// TTS model contexts. key: model_name, value: Graph
pub(crate) static TTS_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlTtsMetadata>>>> =
    OnceCell::new();
// Global running mode. Initialized by the first `init_ggml_*_context` call;
// later calls OR additional mode flags into it (see the init functions below).
pub(crate) static RUNNING_MODE: OnceCell<RwLock<RunningMode>> = OnceCell::new();
// stable diffusion context for the text-to-image task
pub(crate) static SD_TEXT_TO_IMAGE: OnceCell<Mutex<TextToImage>> = OnceCell::new();
// stable diffusion context for the image-to-image task
pub(crate) static SD_IMAGE_TO_IMAGE: OnceCell<Mutex<ImageToImage>> = OnceCell::new();
// context for the audio (whisper) transcription task
#[cfg(feature = "whisper")]
pub(crate) static AUDIO_GRAPH: OnceCell<Mutex<Graph<WhisperMetadata>>> = OnceCell::new();
// context for the piper text-to-speech task
pub(crate) static PIPER_GRAPH: OnceCell<Mutex<Graph<PiperMetadata>>> = OnceCell::new();

// Size of the buffer used to read plugin output tensors (2^14 * 15 + 128 bytes).
pub(crate) const MAX_BUFFER_SIZE: usize = 2usize.pow(14) * 15 + 128;
// Index of the output tensor read back from the wasi-nn graph.
pub(crate) const OUTPUT_TENSOR: usize = 0;
// Output index used to fetch the plugin's version metadata (see `get_plugin_info_by_graph`).
const PLUGIN_VERSION: usize = 1;

/// The directory for storing the archives in wasm virtual file system.
pub const ARCHIVES_DIR: &str = "archives";
67
68/// Initialize the ggml context
69pub fn init_ggml_chat_context(metadata_for_chats: &[GgmlMetadata]) -> Result<(), LlamaCoreError> {
70    #[cfg(feature = "logging")]
71    info!(target: "stdout", "Initializing the core context");
72
73    if metadata_for_chats.is_empty() {
74        let err_msg = "The metadata for chat models is empty";
75
76        #[cfg(feature = "logging")]
77        error!(target: "stdout", "{err_msg}");
78
79        return Err(LlamaCoreError::InitContext(err_msg.into()));
80    }
81
82    let mut chat_graphs = HashMap::new();
83    for metadata in metadata_for_chats {
84        let graph = Graph::new(metadata.clone())?;
85
86        chat_graphs.insert(graph.name().to_string(), graph);
87    }
88
89    // Pre-create per-model locks for each chat model
90    // This ensures locks are ready before any requests arrive
91    for model_name in chat_graphs.keys() {
92        chat::chat_completions::get_or_create_model_lock(model_name)?;
93
94        #[cfg(feature = "logging")]
95        info!(target: "stdout", "Pre-created stream lock for model: {}", model_name);
96    }
97
98    CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
99            let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";
100
101            #[cfg(feature = "logging")]
102            error!(target: "stdout", "{err_msg}");
103
104            LlamaCoreError::InitContext(err_msg.into())
105        })?;
106
107    // set running mode
108    let running_mode = RunningMode::CHAT;
109    match RUNNING_MODE.get() {
110        Some(mode) => {
111            let mut mode = mode.write().unwrap();
112            *mode |= running_mode;
113        }
114        None => {
115            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
116                let err_msg = "Failed to initialize the chat context. Reason: The `RUNNING_MODE` has already been initialized";
117
118                #[cfg(feature = "logging")]
119                error!(target: "stdout", "{err_msg}");
120
121                LlamaCoreError::InitContext(err_msg.into())
122            })?;
123        }
124    }
125
126    Ok(())
127}
128
129/// Initialize the ggml context
130pub fn init_ggml_embeddings_context(
131    metadata_for_embeddings: &[GgmlMetadata],
132) -> Result<(), LlamaCoreError> {
133    #[cfg(feature = "logging")]
134    info!(target: "stdout", "Initializing the embeddings context");
135
136    if metadata_for_embeddings.is_empty() {
137        let err_msg = "The metadata for chat models is empty";
138
139        #[cfg(feature = "logging")]
140        error!(target: "stdout", "{err_msg}");
141
142        return Err(LlamaCoreError::InitContext(err_msg.into()));
143    }
144
145    let mut embedding_graphs = HashMap::new();
146    for metadata in metadata_for_embeddings {
147        let graph = Graph::new(metadata.clone())?;
148
149        embedding_graphs.insert(graph.name().to_string(), graph);
150    }
151    EMBEDDING_GRAPHS
152            .set(Mutex::new(embedding_graphs))
153            .map_err(|_| {
154                let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";
155
156                #[cfg(feature = "logging")]
157                error!(target: "stdout", "{err_msg}");
158
159                LlamaCoreError::InitContext(err_msg.into())
160            })?;
161
162    // set running mode
163    let running_mode = RunningMode::EMBEDDINGS;
164    match RUNNING_MODE.get() {
165        Some(mode) => {
166            let mut mode = mode.write().unwrap();
167            *mode |= running_mode;
168        }
169        None => {
170            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
171                let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
172
173                #[cfg(feature = "logging")]
174                error!(target: "stdout", "{err_msg}");
175
176                LlamaCoreError::InitContext(err_msg.into())
177            })?;
178        }
179    }
180
181    Ok(())
182}
183
184/// Initialize the ggml context for TTS scenarios.
185pub fn init_ggml_tts_context(metadata_for_tts: &[GgmlTtsMetadata]) -> Result<(), LlamaCoreError> {
186    #[cfg(feature = "logging")]
187    info!(target: "stdout", "Initializing the TTS context");
188
189    if metadata_for_tts.is_empty() {
190        let err_msg = "The metadata for tts models is empty";
191
192        #[cfg(feature = "logging")]
193        error!(target: "stdout", "{err_msg}");
194
195        return Err(LlamaCoreError::InitContext(err_msg.into()));
196    }
197
198    let mut tts_graphs = HashMap::new();
199    for metadata in metadata_for_tts {
200        let graph = Graph::new(metadata.clone())?;
201
202        tts_graphs.insert(graph.name().to_string(), graph);
203    }
204    TTS_GRAPHS.set(Mutex::new(tts_graphs)).map_err(|_| {
205        let err_msg = "Failed to initialize the core context. Reason: The `TTS_GRAPHS` has already been initialized";
206
207        #[cfg(feature = "logging")]
208        error!(target: "stdout", "{err_msg}");
209
210        LlamaCoreError::InitContext(err_msg.into())
211    })?;
212
213    // set running mode
214    let running_mode = RunningMode::TTS;
215    match RUNNING_MODE.get() {
216        Some(mode) => {
217            let mut mode = mode.write().unwrap();
218            *mode |= running_mode;
219        }
220        None => {
221            RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
222                let err_msg = "Failed to initialize the embeddings context. Reason: The `RUNNING_MODE` has already been initialized";
223
224                #[cfg(feature = "logging")]
225                error!(target: "stdout", "{err_msg}");
226
227                LlamaCoreError::InitContext(err_msg.into())
228            })?;
229        }
230    }
231
232    Ok(())
233}
234
235/// Get the plugin info
236///
237/// Note that it is required to call `init_core_context` before calling this function.
238pub fn get_plugin_info() -> Result<PluginInfo, LlamaCoreError> {
239    #[cfg(feature = "logging")]
240    debug!(target: "stdout", "Getting the plugin info");
241
242    let running_mode = running_mode()?;
243
244    if running_mode.contains(RunningMode::CHAT) || running_mode.contains(RunningMode::RAG) {
245        let chat_graphs = match CHAT_GRAPHS.get() {
246            Some(chat_graphs) => chat_graphs,
247            None => {
248                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
249
250                #[cfg(feature = "logging")]
251                error!(target: "stdout", "{err_msg}");
252
253                return Err(LlamaCoreError::Operation(err_msg.into()));
254            }
255        };
256
257        let chat_graphs = chat_graphs.lock().map_err(|e| {
258            let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {e}");
259
260            #[cfg(feature = "logging")]
261            error!(target: "stdout", "{}", &err_msg);
262
263            LlamaCoreError::Operation(err_msg)
264        })?;
265
266        let graph = match chat_graphs.values().next() {
267            Some(graph) => graph,
268            None => {
269                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";
270
271                #[cfg(feature = "logging")]
272                error!(target: "stdout", "{err_msg}");
273
274                return Err(LlamaCoreError::Operation(err_msg.into()));
275            }
276        };
277
278        get_plugin_info_by_graph(graph)
279    } else if running_mode.contains(RunningMode::EMBEDDINGS) {
280        let embedding_graphs = match EMBEDDING_GRAPHS.get() {
281            Some(embedding_graphs) => embedding_graphs,
282            None => {
283                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
284
285                #[cfg(feature = "logging")]
286                error!(target: "stdout", "{err_msg}");
287
288                return Err(LlamaCoreError::Operation(err_msg.into()));
289            }
290        };
291
292        let embedding_graphs = embedding_graphs.lock().map_err(|e| {
293            let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {e}");
294
295            #[cfg(feature = "logging")]
296            error!(target: "stdout", "{}", &err_msg);
297
298            LlamaCoreError::Operation(err_msg)
299        })?;
300
301        let graph = match embedding_graphs.values().next() {
302            Some(graph) => graph,
303            None => {
304                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";
305
306                #[cfg(feature = "logging")]
307                error!(target: "stdout", "{err_msg}");
308
309                return Err(LlamaCoreError::Operation(err_msg.into()));
310            }
311        };
312
313        get_plugin_info_by_graph(graph)
314    } else if running_mode.contains(RunningMode::TTS) {
315        let tts_graphs = match TTS_GRAPHS.get() {
316            Some(tts_graphs) => tts_graphs,
317            None => {
318                let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
319
320                #[cfg(feature = "logging")]
321                error!(target: "stdout", "{err_msg}");
322
323                return Err(LlamaCoreError::Operation(err_msg.into()));
324            }
325        };
326
327        let tts_graphs = tts_graphs.lock().map_err(|e| {
328            let err_msg = format!("Fail to acquire the lock of `TTS_GRAPHS`. {e}");
329
330            #[cfg(feature = "logging")]
331            error!(target: "stdout", "{}", &err_msg);
332
333            LlamaCoreError::Operation(err_msg)
334        })?;
335
336        let graph = match tts_graphs.values().next() {
337            Some(graph) => graph,
338            None => {
339                let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
340
341                #[cfg(feature = "logging")]
342                error!(target: "stdout", "{err_msg}");
343
344                return Err(LlamaCoreError::Operation(err_msg.into()));
345            }
346        };
347
348        get_plugin_info_by_graph(graph)
349    } else {
350        let err_msg = "RUNNING_MODE is not set";
351
352        #[cfg(feature = "logging")]
353        error!(target: "stdout", "{err_msg}");
354
355        Err(LlamaCoreError::Operation(err_msg.into()))
356    }
357}
358
359fn get_plugin_info_by_graph<M: BaseMetadata + serde::Serialize + Clone + Default>(
360    graph: &Graph<M>,
361) -> Result<PluginInfo, LlamaCoreError> {
362    #[cfg(feature = "logging")]
363    debug!(target: "stdout", "Getting the plugin info by the graph named {}", graph.name());
364
365    // get the plugin metadata
366    let output_buffer = get_output_buffer(graph, PLUGIN_VERSION)?;
367    let metadata: serde_json::Value = serde_json::from_slice(&output_buffer[..]).map_err(|e| {
368        let err_msg = format!("Fail to deserialize the plugin metadata. {e}");
369
370        #[cfg(feature = "logging")]
371        error!(target: "stdout", "{}", &err_msg);
372
373        LlamaCoreError::Operation(err_msg)
374    })?;
375
376    // get build number of the plugin
377    let plugin_build_number = match metadata.get("llama_build_number") {
378        Some(value) => match value.as_u64() {
379            Some(number) => number,
380            None => {
381                let err_msg = "Failed to convert the build number of the plugin to u64";
382
383                #[cfg(feature = "logging")]
384                error!(target: "stdout", "{err_msg}");
385
386                return Err(LlamaCoreError::Operation(err_msg.into()));
387            }
388        },
389        None => {
390            let err_msg = "Metadata does not have the field `llama_build_number`.";
391
392            #[cfg(feature = "logging")]
393            error!(target: "stdout", "{err_msg}");
394
395            return Err(LlamaCoreError::Operation(err_msg.into()));
396        }
397    };
398
399    // get commit id of the plugin
400    let plugin_commit = match metadata.get("llama_commit") {
401        Some(value) => match value.as_str() {
402            Some(commit) => commit,
403            None => {
404                let err_msg = "Failed to convert the commit id of the plugin to string";
405
406                #[cfg(feature = "logging")]
407                error!(target: "stdout", "{err_msg}");
408
409                return Err(LlamaCoreError::Operation(err_msg.into()));
410            }
411        },
412        None => {
413            let err_msg = "Metadata does not have the field `llama_commit`.";
414
415            #[cfg(feature = "logging")]
416            error!(target: "stdout", "{err_msg}");
417
418            return Err(LlamaCoreError::Operation(err_msg.into()));
419        }
420    };
421
422    #[cfg(feature = "logging")]
423    debug!(target: "stdout", "Plugin info: b{plugin_build_number}(commit {plugin_commit})");
424
425    Ok(PluginInfo {
426        build_number: plugin_build_number,
427        commit_id: plugin_commit.to_string(),
428    })
429}
430
/// Version info of the `wasi-nn_ggml` plugin, including the build number and the commit id.
#[derive(Debug, Clone)]
pub struct PluginInfo {
    /// Build number reported by the plugin.
    pub build_number: u64,
    /// Git commit id the plugin was built from.
    pub commit_id: String,
}

impl std::fmt::Display for PluginInfo {
    /// Renders as `wasinn-ggml plugin: b<build>(commit <commit>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let PluginInfo {
            build_number,
            commit_id,
        } = self;
        write!(f, "wasinn-ggml plugin: b{build_number}(commit {commit_id})")
    }
}
446
447/// Return the current running mode.
448pub fn running_mode() -> Result<RunningMode, LlamaCoreError> {
449    #[cfg(feature = "logging")]
450    debug!(target: "stdout", "Get the running mode.");
451
452    match RUNNING_MODE.get() {
453        Some(mode) => match mode.read() {
454            Ok(mode) => Ok(*mode),
455            Err(e) => {
456                let err_msg = format!("Fail to get the underlying value of `RUNNING_MODE`. {e}");
457
458                #[cfg(feature = "logging")]
459                error!(target: "stdout", "{err_msg}");
460
461                Err(LlamaCoreError::Operation(err_msg))
462            }
463        },
464        None => {
465            let err_msg = "Fail to get the underlying value of `RUNNING_MODE`.";
466
467            #[cfg(feature = "logging")]
468            error!(target: "stdout", "{err_msg}");
469
470            Err(LlamaCoreError::Operation(err_msg.into()))
471        }
472    }
473}
474
475/// Initialize the stable-diffusion context with the given full diffusion model
476///
477/// # Arguments
478///
479/// * `model_file` - Path to the stable diffusion model file.
480///
481/// * `lora_model_dir` - Path to the Lora model directory.
482///
483/// * `controlnet_path` - Path to the controlnet model file.
484///
485/// * `controlnet_on_cpu` - Whether to run the controlnet on CPU.
486///
487/// * `clip_on_cpu` - Whether to run the CLIP on CPU.
488///
489/// * `vae_on_cpu` - Whether to run the VAE on CPU.
490///
491/// * `n_threads` - Number of threads to use.
492///
493/// * `task` - The task type to perform.
494#[allow(clippy::too_many_arguments)]
495pub fn init_sd_context_with_full_model(
496    model_file: impl AsRef<str>,
497    lora_model_dir: Option<&str>,
498    controlnet_path: Option<&str>,
499    controlnet_on_cpu: bool,
500    clip_on_cpu: bool,
501    vae_on_cpu: bool,
502    n_threads: i32,
503    task: StableDiffusionTask,
504) -> Result<(), LlamaCoreError> {
505    #[cfg(feature = "logging")]
506    info!(target: "stdout", "Initializing the stable diffusion context with the full model");
507
508    let control_net_on_cpu = match controlnet_path {
509        Some(path) if !path.is_empty() => controlnet_on_cpu,
510        _ => false,
511    };
512
513    // create the stable diffusion context for the text-to-image task
514    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
515        let sd = SDBuidler::new(Task::TextToImage, model_file.as_ref())
516            .map_err(|e| {
517                let err_msg =
518                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
519
520                #[cfg(feature = "logging")]
521                error!(target: "stdout", "{err_msg}");
522
523                LlamaCoreError::InitContext(err_msg)
524            })?
525            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
526            .map_err(|e| {
527                let err_msg =
528                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
529
530                #[cfg(feature = "logging")]
531                error!(target: "stdout", "{err_msg}");
532
533                LlamaCoreError::InitContext(err_msg)
534            })?
535            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
536            .map_err(|e| {
537                let err_msg =
538                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
539
540                #[cfg(feature = "logging")]
541                error!(target: "stdout", "{err_msg}");
542
543                LlamaCoreError::InitContext(err_msg)
544            })?
545            .clip_on_cpu(clip_on_cpu)
546            .vae_on_cpu(vae_on_cpu)
547            .with_n_threads(n_threads)
548            .build();
549
550        #[cfg(feature = "logging")]
551        info!(target: "stdout", "sd: {:?}", &sd);
552
553        let ctx = sd.create_context().map_err(|e| {
554            let err_msg = format!("Fail to create the context. {e}");
555
556            #[cfg(feature = "logging")]
557            error!(target: "stdout", "{}", &err_msg);
558
559            LlamaCoreError::InitContext(err_msg)
560        })?;
561
562        let ctx = match ctx {
563            Context::TextToImage(ctx) => ctx,
564            _ => {
565                let err_msg = "Fail to get the context for the text-to-image task";
566
567                #[cfg(feature = "logging")]
568                error!(target: "stdout", "{err_msg}");
569
570                return Err(LlamaCoreError::InitContext(err_msg.into()));
571            }
572        };
573
574        #[cfg(feature = "logging")]
575        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
576
577        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
578        let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";
579
580        #[cfg(feature = "logging")]
581        error!(target: "stdout", "{err_msg}");
582
583        LlamaCoreError::InitContext(err_msg.into())
584    })?;
585
586        #[cfg(feature = "logging")]
587        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
588    }
589
590    // create the stable diffusion context for the image-to-image task
591    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
592        let sd = SDBuidler::new(Task::ImageToImage, model_file.as_ref())
593            .map_err(|e| {
594                let err_msg =
595                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
596
597                #[cfg(feature = "logging")]
598                error!(target: "stdout", "{err_msg}");
599
600                LlamaCoreError::InitContext(err_msg)
601            })?
602            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
603            .map_err(|e| {
604                let err_msg =
605                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
606
607                #[cfg(feature = "logging")]
608                error!(target: "stdout", "{err_msg}");
609
610                LlamaCoreError::InitContext(err_msg)
611            })?
612            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
613            .map_err(|e| {
614                let err_msg =
615                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
616
617                #[cfg(feature = "logging")]
618                error!(target: "stdout", "{err_msg}");
619
620                LlamaCoreError::InitContext(err_msg)
621            })?
622            .clip_on_cpu(clip_on_cpu)
623            .vae_on_cpu(vae_on_cpu)
624            .with_n_threads(n_threads)
625            .build();
626
627        #[cfg(feature = "logging")]
628        info!(target: "stdout", "sd: {:?}", &sd);
629
630        let ctx = sd.create_context().map_err(|e| {
631            let err_msg = format!("Fail to create the context. {e}");
632
633            #[cfg(feature = "logging")]
634            error!(target: "stdout", "{}", &err_msg);
635
636            LlamaCoreError::InitContext(err_msg)
637        })?;
638
639        let ctx = match ctx {
640            Context::ImageToImage(ctx) => ctx,
641            _ => {
642                let err_msg = "Fail to get the context for the image-to-image task";
643
644                #[cfg(feature = "logging")]
645                error!(target: "stdout", "{err_msg}");
646
647                return Err(LlamaCoreError::InitContext(err_msg.into()));
648            }
649        };
650
651        #[cfg(feature = "logging")]
652        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
653
654        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
655            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";
656
657            #[cfg(feature = "logging")]
658            error!(target: "stdout", "{err_msg}");
659
660            LlamaCoreError::InitContext(err_msg.into())
661        })?;
662
663        #[cfg(feature = "logging")]
664        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
665    }
666
667    Ok(())
668}
669
670/// Initialize the stable-diffusion context with the given standalone diffusion model
671///
672/// # Arguments
673///
674/// * `model_file` - Path to the standalone diffusion model file.
675///
676/// * `vae` - Path to the VAE model file.
677///
678/// * `clip_l` - Path to the CLIP model file.
679///
680/// * `t5xxl` - Path to the T5-XXL model file.
681///
682/// * `lora_model_dir` - Path to the Lora model directory.
683///
684/// * `controlnet_path` - Path to the controlnet model file.
685///
686/// * `controlnet_on_cpu` - Whether to run the controlnet on CPU.
687///
688/// * `clip_on_cpu` - Whether to run the CLIP on CPU.
689///
690/// * `vae_on_cpu` - Whether to run the VAE on CPU.
691///
692/// * `n_threads` - Number of threads to use.
693///
694/// * `task` - The task type to perform.
695#[allow(clippy::too_many_arguments)]
696pub fn init_sd_context_with_standalone_model(
697    model_file: impl AsRef<str>,
698    vae: impl AsRef<str>,
699    clip_l: impl AsRef<str>,
700    t5xxl: impl AsRef<str>,
701    lora_model_dir: Option<&str>,
702    controlnet_path: Option<&str>,
703    controlnet_on_cpu: bool,
704    clip_on_cpu: bool,
705    vae_on_cpu: bool,
706    n_threads: i32,
707    task: StableDiffusionTask,
708) -> Result<(), LlamaCoreError> {
709    #[cfg(feature = "logging")]
710    info!(target: "stdout", "Initializing the stable diffusion context with the standalone diffusion model");
711
712    let control_net_on_cpu = match controlnet_path {
713        Some(path) if !path.is_empty() => controlnet_on_cpu,
714        _ => false,
715    };
716
717    // create the stable diffusion context for the text-to-image task
718    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
719        let sd = SDBuidler::new_with_standalone_model(Task::TextToImage, model_file.as_ref())
720            .map_err(|e| {
721                let err_msg =
722                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
723
724                #[cfg(feature = "logging")]
725                error!(target: "stdout", "{err_msg}");
726
727                LlamaCoreError::InitContext(err_msg)
728            })?
729            .with_vae_path(vae.as_ref())
730            .map_err(|e| {
731                let err_msg =
732                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
733
734                #[cfg(feature = "logging")]
735                error!(target: "stdout", "{err_msg}");
736
737                LlamaCoreError::InitContext(err_msg)
738            })?
739            .with_clip_l_path(clip_l.as_ref())
740            .map_err(|e| {
741                let err_msg =
742                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
743
744                #[cfg(feature = "logging")]
745                error!(target: "stdout", "{err_msg}");
746
747                LlamaCoreError::InitContext(err_msg)
748            })?
749            .with_t5xxl_path(t5xxl.as_ref())
750            .map_err(|e| {
751                let err_msg =
752                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
753
754                #[cfg(feature = "logging")]
755                error!(target: "stdout", "{err_msg}");
756
757                LlamaCoreError::InitContext(err_msg)
758            })?
759            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
760            .map_err(|e| {
761                let err_msg =
762                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
763
764                #[cfg(feature = "logging")]
765                error!(target: "stdout", "{err_msg}");
766
767                LlamaCoreError::InitContext(err_msg)
768            })?
769            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
770            .map_err(|e| {
771                let err_msg =
772                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
773
774                #[cfg(feature = "logging")]
775                error!(target: "stdout", "{err_msg}");
776
777                LlamaCoreError::InitContext(err_msg)
778            })?
779            .clip_on_cpu(clip_on_cpu)
780            .vae_on_cpu(vae_on_cpu)
781            .with_n_threads(n_threads)
782            .build();
783
784        #[cfg(feature = "logging")]
785        info!(target: "stdout", "sd: {:?}", &sd);
786
787        let ctx = sd.create_context().map_err(|e| {
788            let err_msg = format!("Fail to create the context. {e}");
789
790            #[cfg(feature = "logging")]
791            error!(target: "stdout", "{}", &err_msg);
792
793            LlamaCoreError::InitContext(err_msg)
794        })?;
795
796        let ctx = match ctx {
797            Context::TextToImage(ctx) => ctx,
798            _ => {
799                let err_msg = "Fail to get the context for the text-to-image task";
800
801                #[cfg(feature = "logging")]
802                error!(target: "stdout", "{err_msg}");
803
804                return Err(LlamaCoreError::InitContext(err_msg.into()));
805            }
806        };
807
808        #[cfg(feature = "logging")]
809        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
810
811        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
812            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";
813
814            #[cfg(feature = "logging")]
815            error!(target: "stdout", "{err_msg}");
816
817            LlamaCoreError::InitContext(err_msg.into())
818        })?;
819
820        #[cfg(feature = "logging")]
821        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
822    }
823
824    // create the stable diffusion context for the image-to-image task
825    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
826        let sd = SDBuidler::new_with_standalone_model(Task::ImageToImage, model_file.as_ref())
827            .map_err(|e| {
828                let err_msg =
829                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
830
831                #[cfg(feature = "logging")]
832                error!(target: "stdout", "{err_msg}");
833
834                LlamaCoreError::InitContext(err_msg)
835            })?
836            .with_vae_path(vae.as_ref())
837            .map_err(|e| {
838                let err_msg =
839                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
840
841                #[cfg(feature = "logging")]
842                error!(target: "stdout", "{err_msg}");
843
844                LlamaCoreError::InitContext(err_msg)
845            })?
846            .with_clip_l_path(clip_l.as_ref())
847            .map_err(|e| {
848                let err_msg =
849                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
850
851                #[cfg(feature = "logging")]
852                error!(target: "stdout", "{err_msg}");
853
854                LlamaCoreError::InitContext(err_msg)
855            })?
856            .with_t5xxl_path(t5xxl.as_ref())
857            .map_err(|e| {
858                let err_msg =
859                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
860
861                #[cfg(feature = "logging")]
862                error!(target: "stdout", "{err_msg}");
863
864                LlamaCoreError::InitContext(err_msg)
865            })?
866            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
867            .map_err(|e| {
868                let err_msg =
869                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
870
871                #[cfg(feature = "logging")]
872                error!(target: "stdout", "{err_msg}");
873
874                LlamaCoreError::InitContext(err_msg)
875            })?
876            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
877            .map_err(|e| {
878                let err_msg =
879                    format!("Failed to initialize the stable diffusion context. Reason: {e}");
880
881                #[cfg(feature = "logging")]
882                error!(target: "stdout", "{err_msg}");
883
884                LlamaCoreError::InitContext(err_msg)
885            })?
886            .clip_on_cpu(clip_on_cpu)
887            .vae_on_cpu(vae_on_cpu)
888            .with_n_threads(n_threads)
889            .build();
890
891        #[cfg(feature = "logging")]
892        info!(target: "stdout", "sd: {:?}", &sd);
893
894        let ctx = sd.create_context().map_err(|e| {
895            let err_msg = format!("Fail to create the context. {e}");
896
897            #[cfg(feature = "logging")]
898            error!(target: "stdout", "{}", &err_msg);
899
900            LlamaCoreError::InitContext(err_msg)
901        })?;
902
903        let ctx = match ctx {
904            Context::ImageToImage(ctx) => ctx,
905            _ => {
906                let err_msg = "Fail to get the context for the image-to-image task";
907
908                #[cfg(feature = "logging")]
909                error!(target: "stdout", "{err_msg}");
910
911                return Err(LlamaCoreError::InitContext(err_msg.into()));
912            }
913        };
914
915        #[cfg(feature = "logging")]
916        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
917
918        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
919        let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";
920
921        #[cfg(feature = "logging")]
922        error!(target: "stdout", "{err_msg}");
923
924        LlamaCoreError::InitContext(err_msg.into())
925    })?;
926
927        #[cfg(feature = "logging")]
928        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
929    }
930
931    Ok(())
932}
933
/// Selects which stable diffusion context(s) to create during initialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StableDiffusionTask {
    /// Create only the `text_to_image` context.
    TextToImage,
    /// Create only the `image_to_image` context.
    ImageToImage,
    /// Create both the `text_to_image` and `image_to_image` contexts.
    Full,
}
944
/// Initialize the whisper context
#[cfg(feature = "whisper")]
pub fn init_whisper_context(whisper_metadata: &WhisperMetadata) -> Result<(), LlamaCoreError> {
    // Build the audio graph up front; it either replaces the existing graph or
    // becomes the initial value of the global cell below.
    let graph = GraphBuilder::new(EngineType::Whisper)?
        .with_config(whisper_metadata.clone())?
        .use_cpu()
        .build_from_files([&whisper_metadata.model_path])?;

    if let Some(mutex_graph) = AUDIO_GRAPH.get() {
        // A graph already exists: swap the fresh one in behind the lock.
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Re-initialize the audio context");

        let mut locked_graph = mutex_graph.lock().map_err(|e| {
            let err_msg = format!("Failed to lock the graph. Reason: {e}");

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg)
        })?;
        *locked_graph = graph;
    } else {
        // First initialization: publish the graph into the global cell.
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Initialize the audio context");

        AUDIO_GRAPH.set(Mutex::new(graph)).map_err(|_| {
            let err_msg = "Failed to initialize the audio context. Reason: The `AUDIO_GRAPH` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{err_msg}");

            LlamaCoreError::InitContext(err_msg.into())
        })?;
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "The audio context has been initialized");

    Ok(())
}
991
992/// Initialize the piper context
993///
994/// # Arguments
995///
996/// * `voice_model` - Path to the voice model file.
997///
998/// * `voice_config` - Path to the voice config file.
999///
1000/// * `espeak_ng_data` - Path to the espeak-ng data directory.
1001///
1002pub fn init_piper_context(
1003    piper_metadata: &PiperMetadata,
1004    voice_model: impl AsRef<Path>,
1005    voice_config: impl AsRef<Path>,
1006    espeak_ng_data: impl AsRef<Path>,
1007) -> Result<(), LlamaCoreError> {
1008    #[cfg(feature = "logging")]
1009    info!(target: "stdout", "Initializing the piper context");
1010
1011    let config = serde_json::json!({
1012        "model": voice_model.as_ref().to_owned(),
1013        "config": voice_config.as_ref().to_owned(),
1014        "espeak_data": espeak_ng_data.as_ref().to_owned(),
1015    });
1016
1017    // create and initialize the audio context
1018    let graph = GraphBuilder::new(EngineType::Piper)?
1019        .with_config(piper_metadata.clone())?
1020        .use_cpu()
1021        .build_from_buffer([config.to_string()])?;
1022
1023    PIPER_GRAPH.set(Mutex::new(graph)).map_err(|_| {
1024            let err_msg = "Failed to initialize the piper context. Reason: The `PIPER_GRAPH` has already been initialized";
1025
1026            #[cfg(feature = "logging")]
1027            error!(target: "stdout", "{err_msg}");
1028
1029            LlamaCoreError::InitContext(err_msg.into())
1030        })?;
1031
1032    #[cfg(feature = "logging")]
1033    info!(target: "stdout", "The piper context has been initialized");
1034
1035    Ok(())
1036}