//! Audio endpoints for `llama-core` (`audio.rs`): transcription, translation,
//! and speech generation.

1use crate::{error::LlamaCoreError, utils::set_tensor_data, MAX_BUFFER_SIZE, PIPER_GRAPH};
2#[cfg(feature = "whisper")]
3use crate::{metadata::whisper::WhisperMetadata, AUDIO_GRAPH};
4use endpoints::audio::speech::SpeechRequest;
5#[cfg(feature = "whisper")]
6use endpoints::audio::{
7    transcription::{TranscriptionObject, TranscriptionRequest},
8    translation::{TranslationObject, TranslationRequest},
9};
10#[cfg(feature = "whisper")]
11use std::path::Path;
12
/// Transcribe audio into the input language.
#[cfg(feature = "whisper")]
#[cfg_attr(docsrs, doc(cfg(feature = "whisper")))]
pub async fn audio_transcriptions(
    request: TranscriptionRequest,
) -> Result<TranscriptionObject, LlamaCoreError> {
    // Run the transcription first so the metadata reset below happens
    // regardless of whether the request succeeded.
    let result = transcribe_audio(request).await;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata.");

    // Restore the model's original metadata before reporting the outcome.
    reset_model_metadata()?;

    result
}
29
/// Run the whisper model on the uploaded audio file and return the
/// transcription as plain text.
///
/// Per-request fields (language, temperature, ...) are written into the model
/// metadata before inference when they differ from the current values; the
/// caller is expected to reset the metadata afterwards.
#[cfg(feature = "whisper")]
async fn transcribe_audio(
    request: TranscriptionRequest,
) -> Result<TranscriptionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "processing audio transcription request");

    let graph = match AUDIO_GRAPH.get() {
        Some(graph) => graph,
        None => {
            let err_msg = "The AUDIO_GRAPH is not initialized.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
        }
    };

    let mut graph = match graph.lock() {
        Ok(graph) => graph,
        Err(e) => {
            let err_msg = format!("Failed to lock the graph. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    // check if the model metadata should be updated
    {
        let mut should_update = false;
        let mut metadata = graph.metadata.clone();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "current metadata: {:?}", &metadata);

        #[cfg(feature = "logging")]
        info!(target: "stdout", "Check model metadata.");

        // transcription must not translate; clear the flag if it is set
        if metadata.translate {
            metadata.translate = false;
            should_update = true;
        }

        // check `language` field
        if let Some(language) = &request.language {
            if *language != metadata.language {
                metadata.language = language.clone();
                should_update = true;
            }
        }

        // check `detect_language` field
        if let Some(detect_language) = &request.detect_language {
            if *detect_language != metadata.detect_language {
                metadata.detect_language = *detect_language;
                should_update = true;
            }
        }

        // check `offset_time` field
        if let Some(offset_time) = &request.offset_time {
            if *offset_time != metadata.offset_time {
                metadata.offset_time = *offset_time;
                should_update = true;
            }
        }

        // check `duration` field
        if let Some(duration) = &request.duration {
            if *duration != metadata.duration {
                metadata.duration = *duration;
                should_update = true;
            }
        }

        // check `max_context` field
        if let Some(max_context) = &request.max_context {
            if *max_context != metadata.max_context {
                metadata.max_context = *max_context;
                should_update = true;
            }
        }

        // check `max_len` field
        if let Some(max_len) = &request.max_len {
            if *max_len != metadata.max_len {
                metadata.max_len = *max_len;
                should_update = true;
            }
        }

        // check `temperature` field
        if let Some(temperature) = &request.temperature {
            if *temperature != metadata.temperature {
                metadata.temperature = *temperature;
                should_update = true;
            }
        }

        // check `split_on_word` field
        if let Some(split_on_word) = &request.split_on_word {
            if *split_on_word != metadata.split_on_word {
                metadata.split_on_word = *split_on_word;
                should_update = true;
            }
        }

        // check `prompt` field (only a non-empty prompt overrides the current one)
        if let Some(prompt) = &request.prompt {
            if !prompt.is_empty() && metadata.prompt.as_ref() != Some(prompt) {
                metadata.prompt = Some(prompt.clone());
                should_update = true;
            }
        }

        if should_update {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Set the metadata to the model.");

            // Serialize once and reuse the result for both logging and the
            // update, instead of unwrapping a second serialization in the
            // debug log (which could panic before the error was handled).
            match serde_json::to_string(&metadata) {
                Ok(config) => {
                    #[cfg(feature = "logging")]
                    debug!(target: "stdout", "new metadata: {}", &config);

                    // update metadata (tensor slot 1 holds the model config)
                    set_tensor_data(&mut graph, 1, config.as_bytes(), [1])?;

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "metadata updated");
                }
                Err(e) => {
                    let err_msg = format!("Fail to serialize metadata to a JSON string. {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg));
                }
            }
        }
    }

    // uploaded files live under `archives/<file-id>/<filename>`
    let path = Path::new("archives")
        .join(&request.file.id)
        .join(&request.file.filename);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "audio file path: {:?}", &path);

    // load the audio waveform
    let wav_buf = load_audio_waveform(path)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "read input tensor, size in bytes: {}", wav_buf.len());

    // set the input tensor (slot 0 holds the raw audio bytes)
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the audio data to the model.");
    set_tensor_data(&mut graph, 0, &wav_buf, [1, wav_buf.len()])?;

    // compute the graph
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Transcribe audio to text.");
    if let Err(e) = graph.compute() {
        let err_msg = format!("Failed to compute the graph. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        return Err(LlamaCoreError::Operation(err_msg));
    }

    // get the output tensor
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Retrieve the transcription data.");

    // Retrieve the output.
    let mut output_buffer = vec![0u8; MAX_BUFFER_SIZE];
    let output_size = graph.get_output(0, &mut output_buffer).map_err(|e| {
        let err_msg = format!("Failed to get the output tensor. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Output buffer size: {}", output_size);

    // decode only the bytes actually written by the model
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Decode the transcription data to plain text.");

    let text = String::from_utf8_lossy(&output_buffer[..output_size]);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "raw transcription text:\n{}", &text);

    let obj = TranscriptionObject {
        text: text.trim().to_owned(),
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the audio transcription.");

    Ok(obj)
}
295
/// Read the raw bytes of an audio file into memory.
///
/// Returns `LlamaCoreError::Operation` when the file cannot be read.
#[cfg(feature = "whisper")]
fn load_audio_waveform(filename: impl AsRef<std::path::Path>) -> Result<Vec<u8>, LlamaCoreError> {
    // A single map_err suffices: the closure already builds (and logs) the
    // final LlamaCoreError; the original second map_err re-wrapped it for
    // no benefit.
    std::fs::read(filename).map_err(|e| {
        let err_msg = format!("Failed to read the input tensor. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })
}
309
/// Drop every transcript line that contains the `[BLANK_AUDIO]` marker,
/// rejoining the surviving lines with `\n`.
fn _remove_blank_audio(input: &str) -> String {
    const MARKER: &str = "[BLANK_AUDIO]";

    input
        .lines()
        .filter(|line| !line.contains(MARKER))
        .collect::<Vec<_>>()
        .join("\n")
}
322
/// Translate audio into the target language
#[cfg(feature = "whisper")]
#[cfg_attr(docsrs, doc(cfg(feature = "whisper")))]
pub async fn audio_translations(
    request: TranslationRequest,
) -> Result<TranslationObject, LlamaCoreError> {
    // Run the translation first so the metadata reset below happens
    // regardless of whether the request succeeded.
    let result = translate_audio(request).await;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata.");

    // Restore the model's original metadata before reporting the outcome.
    reset_model_metadata()?;

    result
}
339
/// Run the whisper model on the uploaded audio file and return the
/// translation (into English) as plain text.
///
/// Per-request fields (language, temperature, ...) are written into the model
/// metadata before inference when they differ from the current values; the
/// caller is expected to reset the metadata afterwards.
#[cfg(feature = "whisper")]
async fn translate_audio(request: TranslationRequest) -> Result<TranslationObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "processing audio translation request");

    let graph = match AUDIO_GRAPH.get() {
        Some(graph) => graph,
        None => {
            let err_msg = "The AUDIO_GRAPH is not initialized.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
        }
    };

    let mut graph = match graph.lock() {
        Ok(graph) => graph,
        Err(e) => {
            let err_msg = format!("Failed to lock the graph. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    // check if the model metadata should be updated
    {
        let mut should_update = false;
        let mut metadata = graph.metadata.clone();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "current metadata: {:?}", &metadata);

        #[cfg(feature = "logging")]
        info!(target: "stdout", "Check model metadata.");

        // translation requires the translate flag to be set
        if !metadata.translate {
            metadata.translate = true;
            should_update = true;
        }

        // check `language` field
        if let Some(language) = &request.language {
            if *language != metadata.language {
                metadata.language = language.clone();
                should_update = true;
            }
        }

        // check `detect_language` field
        if let Some(detect_language) = &request.detect_language {
            if *detect_language != metadata.detect_language {
                metadata.detect_language = *detect_language;
                should_update = true;
            }
        }

        // check `offset_time` field
        if let Some(offset_time) = &request.offset_time {
            if *offset_time != metadata.offset_time {
                metadata.offset_time = *offset_time;
                should_update = true;
            }
        }

        // check `duration` field
        if let Some(duration) = &request.duration {
            if *duration != metadata.duration {
                metadata.duration = *duration;
                should_update = true;
            }
        }

        // check `max_context` field
        if let Some(max_context) = &request.max_context {
            if *max_context != metadata.max_context {
                metadata.max_context = *max_context;
                should_update = true;
            }
        }

        // check `max_len` field
        if let Some(max_len) = &request.max_len {
            if *max_len != metadata.max_len {
                metadata.max_len = *max_len;
                should_update = true;
            }
        }

        // check `temperature` field
        if let Some(temperature) = &request.temperature {
            if *temperature != metadata.temperature {
                metadata.temperature = *temperature;
                should_update = true;
            }
        }

        // check `split_on_word` field
        if let Some(split_on_word) = &request.split_on_word {
            if *split_on_word != metadata.split_on_word {
                metadata.split_on_word = *split_on_word;
                should_update = true;
            }
        }

        // check `prompt` field (only a non-empty prompt overrides the current one)
        if let Some(prompt) = &request.prompt {
            if !prompt.is_empty() && metadata.prompt.as_ref() != Some(prompt) {
                metadata.prompt = Some(prompt.clone());
                should_update = true;
            }
        }

        if should_update {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Set the metadata to the model.");

            // Serialize once and reuse the result for both logging and the
            // update, instead of unwrapping a second serialization in the
            // debug log (which could panic before the error was handled).
            match serde_json::to_string(&metadata) {
                Ok(config) => {
                    #[cfg(feature = "logging")]
                    debug!(target: "stdout", "new metadata: {}", &config);

                    // update metadata (tensor slot 1 holds the model config)
                    set_tensor_data(&mut graph, 1, config.as_bytes(), [1])?;

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "metadata updated");
                }
                Err(e) => {
                    let err_msg = format!("Fail to serialize metadata to a JSON string. {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg));
                }
            }
        }
    }

    // uploaded files live under `archives/<file-id>/<filename>`
    let path = Path::new("archives")
        .join(&request.file.id)
        .join(&request.file.filename);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "audio file path: {:?}", &path);

    // load the audio waveform
    let wav_buf = load_audio_waveform(path)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "read input tensor, size in bytes: {}", wav_buf.len());

    // set the input tensor (slot 0 holds the raw audio bytes)
    #[cfg(feature = "logging")]
    info!(target: "stdout", "feed the audio data to the model.");
    set_tensor_data(&mut graph, 0, &wav_buf, [1, wav_buf.len()])?;

    // compute the graph
    #[cfg(feature = "logging")]
    info!(target: "stdout", "translate audio to text.");
    if let Err(e) = graph.compute() {
        let err_msg = format!("Failed to compute the graph. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        return Err(LlamaCoreError::Operation(err_msg));
    }

    // get the output tensor
    #[cfg(feature = "logging")]
    info!(target: "stdout", "retrieve the translation data.");

    // Retrieve the output.
    let mut output_buffer = vec![0u8; MAX_BUFFER_SIZE];
    let output_size = graph.get_output(0, &mut output_buffer).map_err(|e| {
        let err_msg = format!("Failed to get the output tensor. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "output buffer size: {}", output_size);

    // decode only the bytes actually written by the model
    #[cfg(feature = "logging")]
    info!(target: "stdout", "decode the translation data to plain text.");

    let text = String::from_utf8_lossy(&output_buffer[..output_size]);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "raw translation text:\n{}", &text);

    let obj = TranslationObject {
        text: text.trim().to_owned(),
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the audio translation.");

    // NOTE: the metadata reset is performed by the caller (audio_translations);
    // the stray "Reset the model metadata." log that used to appear here was
    // misleading and has been removed.
    Ok(obj)
}
595
596/// Generate audio from the input text.
597pub async fn create_speech(request: SpeechRequest) -> Result<Vec<u8>, LlamaCoreError> {
598    #[cfg(feature = "logging")]
599    info!(target: "stdout", "processing audio speech request");
600
601    #[cfg(feature = "logging")]
602    info!(target: "stdout", "Get the model instance.");
603    let graph = match PIPER_GRAPH.get() {
604        Some(graph) => graph,
605        None => {
606            let err_msg = "The PIPER_GRAPH is not initialized.";
607
608            #[cfg(feature = "logging")]
609            error!(target: "stdout", "{}", &err_msg);
610
611            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
612        }
613    };
614
615    let mut graph = match graph.lock() {
616        Ok(graph) => graph,
617        Err(e) => {
618            let err_msg = format!("Failed to lock the graph. {}", e);
619
620            #[cfg(feature = "logging")]
621            error!(target: "stdout", "{}", &err_msg);
622
623            return Err(LlamaCoreError::Operation(err_msg));
624        }
625    };
626
627    // set the input tensor
628    #[cfg(feature = "logging")]
629    info!(target: "stdout", "Feed the text to the model.");
630    set_tensor_data(&mut graph, 0, request.input.as_bytes(), [1])?;
631
632    // compute the graph
633    #[cfg(feature = "logging")]
634    info!(target: "stdout", "create audio.");
635    if let Err(e) = graph.compute() {
636        let err_msg = format!("Failed to compute the graph. {}", e);
637
638        #[cfg(feature = "logging")]
639        error!(target: "stdout", "{}", &err_msg);
640
641        return Err(LlamaCoreError::Operation(err_msg));
642    }
643
644    // get the output tensor
645    #[cfg(feature = "logging")]
646    info!(target: "stdout", "[INFO] Retrieve the audio.");
647
648    let mut output_buffer = vec![0u8; MAX_BUFFER_SIZE];
649    let output_size = graph.get_output(0, &mut output_buffer).map_err(|e| {
650        let err_msg = format!("Failed to get the output tensor. {}", e);
651
652        #[cfg(feature = "logging")]
653        error!(target: "stdout", "{}", &err_msg);
654
655        LlamaCoreError::Operation(err_msg)
656    })?;
657
658    #[cfg(feature = "logging")]
659    info!(target: "stdout", "Output buffer size: {}", output_size);
660
661    Ok(output_buffer)
662}
663
/// Write the model's original metadata back to the model, undoing any
/// per-request overrides applied during transcription/translation.
#[cfg(feature = "whisper")]
fn reset_model_metadata() -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    debug!(target: "stdout", "Get the original metadata.");

    // fetch a copy of the original metadata
    let original = get_model_metadata()?;

    #[cfg(feature = "logging")]
    debug!(target: "stdout", "Set the original metadata to the model.");

    #[cfg(feature = "logging")]
    debug!(target: "stdout", "original metadata: {}", serde_json::to_string(&original).unwrap());

    // push the original metadata back into the model
    update_model_metadata(&original)
}
681
/// Get a copy of the metadata of the model.
#[cfg(feature = "whisper")]
fn get_model_metadata() -> Result<WhisperMetadata, LlamaCoreError> {
    // the audio graph must have been initialized at startup
    let audio_graph = AUDIO_GRAPH.get().ok_or_else(|| {
        let err_msg = "Fail to get the underlying value of `AUDIO_GRAPH`.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        LlamaCoreError::Operation(err_msg.into())
    })?;

    // a poisoned lock is surfaced as an operation error
    let guard = audio_graph.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `AUDIO_GRAPH`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    Ok(guard.metadata.clone())
}
708
/// Serialize `metadata` and write it into the model's config tensor (slot 1).
#[cfg(feature = "whisper")]
fn update_model_metadata(metadata: &WhisperMetadata) -> Result<(), LlamaCoreError> {
    // serialize the metadata to JSON before touching the graph
    let config = serde_json::to_string(metadata).map_err(|e| {
        let err_msg = format!("Fail to serialize metadata to a JSON string. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    // the audio graph must have been initialized at startup
    let audio_graph = AUDIO_GRAPH.get().ok_or_else(|| {
        let err_msg = "Fail to get the underlying value of `AUDIO_GRAPH`.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        LlamaCoreError::Operation(err_msg.into())
    })?;

    // a poisoned lock is surfaced as an operation error
    let mut guard = audio_graph.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `AUDIO_GRAPH`. Reason: {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    // update metadata
    set_tensor_data::<u8, WhisperMetadata>(&mut guard, 1, config.as_bytes(), [1])
}