1use crate::{error::LlamaCoreError, utils::set_tensor_data, MAX_BUFFER_SIZE, PIPER_GRAPH};
2#[cfg(feature = "whisper")]
3use crate::{metadata::whisper::WhisperMetadata, AUDIO_GRAPH};
4use endpoints::audio::speech::SpeechRequest;
5#[cfg(feature = "whisper")]
6use endpoints::audio::{
7 transcription::{TranscriptionObject, TranscriptionRequest},
8 translation::{TranslationObject, TranslationRequest},
9};
10#[cfg(feature = "whisper")]
11use std::path::Path;
12
#[cfg(feature = "whisper")]
#[cfg_attr(docsrs, doc(cfg(feature = "whisper")))]
/// Transcribe an uploaded audio file into text via the whisper model.
///
/// After the transcription attempt — successful or not — the model metadata is
/// restored so per-request overrides do not leak into subsequent requests.
///
/// # Errors
///
/// Returns `LlamaCoreError::Operation` when the transcription fails or the
/// metadata reset fails.
pub async fn audio_transcriptions(
    request: TranscriptionRequest,
) -> Result<TranscriptionObject, LlamaCoreError> {
    // Run the transcription first; keep its result while we reset the metadata.
    let result = transcribe_audio(request).await;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata.");

    reset_model_metadata()?;

    result
}
29
#[cfg(feature = "whisper")]
/// Core transcription path: sync request options into the model metadata, feed
/// the audio file into the whisper graph, run inference, and decode the output.
///
/// # Errors
///
/// Returns `LlamaCoreError::Operation` when the graph is uninitialized, its
/// lock is poisoned, metadata serialization fails, the audio file cannot be
/// read, or graph computation/output retrieval fails.
async fn transcribe_audio(
    request: TranscriptionRequest,
) -> Result<TranscriptionObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "processing audio transcription request");

    // The whisper graph must have been initialized at startup.
    let graph = match AUDIO_GRAPH.get() {
        Some(graph) => graph,
        None => {
            let err_msg = "The AUDIO_GRAPH is not initialized.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
        }
    };

    let mut graph = match graph.lock() {
        Ok(graph) => graph,
        Err(e) => {
            let err_msg = format!("Failed to lock the graph. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    // Sync per-request options into a working copy of the metadata; only push
    // the serialized metadata to the model when something actually changed.
    {
        let mut should_update = false;
        let mut metadata = graph.metadata.clone();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "current metadata: {:?}", &metadata);

        #[cfg(feature = "logging")]
        info!(target: "stdout", "Check model metadata.");

        // Transcription must run with translation disabled.
        if metadata.translate {
            metadata.translate = false;
            should_update = true;
        }

        // Copy a request option into the metadata when it is present and
        // differs from the current value.
        macro_rules! sync_option {
            ($field:ident) => {
                if let Some(value) = &request.$field {
                    if *value != metadata.$field {
                        metadata.$field = value.clone();
                        should_update = true;
                    }
                }
            };
        }

        sync_option!(language);
        sync_option!(detect_language);
        sync_option!(offset_time);
        sync_option!(duration);
        sync_option!(max_context);
        sync_option!(max_len);
        sync_option!(temperature);
        sync_option!(split_on_word);

        // Only a non-empty prompt that differs from the current one counts.
        if let Some(prompt) = &request.prompt {
            if !prompt.is_empty() && metadata.prompt.as_deref() != Some(prompt.as_str()) {
                metadata.prompt = Some(prompt.clone());
                should_update = true;
            }
        }

        if should_update {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Set the metadata to the model.");

            // Serialize once and reuse the result for both logging and the
            // tensor write; the old code serialized twice and could panic via
            // `unwrap()` before the error branch ever ran.
            match serde_json::to_string(&metadata) {
                Ok(config) => {
                    #[cfg(feature = "logging")]
                    debug!(target: "stdout", "new metadata: {}", &config);

                    // Tensor slot 1 carries the serialized metadata.
                    set_tensor_data(&mut graph, 1, config.as_bytes(), [1])?;

                    #[cfg(feature = "logging")]
                    info!(target: "stdout", "metadata updated");
                }
                Err(e) => {
                    let err_msg = format!("Fail to serialize metadata to a JSON string. {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg));
                }
            };
        }
    }

    // Uploaded files are stored under `archives/<file-id>/<filename>`.
    let path = Path::new("archives")
        .join(&request.file.id)
        .join(&request.file.filename);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "audio file path: {:?}", &path);

    let wav_buf = load_audio_waveform(path)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "read input tensor, size in bytes: {}", wav_buf.len());

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the audio data to the model.");
    // Tensor slot 0 carries the raw audio bytes.
    set_tensor_data(&mut graph, 0, &wav_buf, [1, wav_buf.len()])?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Transcribe audio to text.");
    if let Err(e) = graph.compute() {
        let err_msg = format!("Failed to compute the graph. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        return Err(LlamaCoreError::Operation(err_msg));
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Retrieve the transcription data.");

    let mut output_buffer = vec![0u8; MAX_BUFFER_SIZE];
    let output_size = graph.get_output(0, &mut output_buffer).map_err(|e| {
        let err_msg = format!("Failed to get the output tensor. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Output buffer size: {}", output_size);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Decode the transcription data to plain text.");

    // Only the first `output_size` bytes of the buffer are valid output.
    let text = String::from_utf8_lossy(&output_buffer[..output_size]);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "raw transcription text:\n{}", &text);

    let obj = TranscriptionObject {
        text: text.trim().to_owned(),
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the audio transcription.");

    Ok(obj)
}
295
#[cfg(feature = "whisper")]
/// Read the raw bytes of the audio file at `filename`.
///
/// # Errors
///
/// Returns `LlamaCoreError::Operation` when the file cannot be read.
fn load_audio_waveform(filename: impl AsRef<std::path::Path>) -> Result<Vec<u8>, LlamaCoreError> {
    // The old code chained a second `map_err` that re-wrapped the already
    // constructed `LlamaCoreError` via `to_string()`, losing the error variant;
    // one `map_err` is sufficient.
    std::fs::read(filename).map_err(|e| {
        let err_msg = format!("Failed to read the input tensor. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })
}
309
/// Drop every line that contains the whisper `[BLANK_AUDIO]` marker and
/// rejoin the remaining lines with newlines.
fn _remove_blank_audio(input: &str) -> String {
    const BLANK_AUDIO_MARKER: &str = "[BLANK_AUDIO]";

    // Filter line by line and rebuild the text in one pass.
    input
        .lines()
        .filter(|line| !line.contains(BLANK_AUDIO_MARKER))
        .collect::<Vec<_>>()
        .join("\n")
}
322
#[cfg(feature = "whisper")]
#[cfg_attr(docsrs, doc(cfg(feature = "whisper")))]
/// Translate an uploaded audio file into English text via the whisper model.
///
/// After the translation attempt — successful or not — the model metadata is
/// restored so per-request overrides do not leak into subsequent requests.
///
/// # Errors
///
/// Returns `LlamaCoreError::Operation` when the translation fails or the
/// metadata reset fails.
pub async fn audio_translations(
    request: TranslationRequest,
) -> Result<TranslationObject, LlamaCoreError> {
    // Run the translation first; keep its result while we reset the metadata.
    let result = translate_audio(request).await;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata.");

    reset_model_metadata()?;

    result
}
339
#[cfg(feature = "whisper")]
/// Core translation path: sync request options into the model metadata, feed
/// the audio file into the whisper graph, run inference, and decode the output.
///
/// # Errors
///
/// Returns `LlamaCoreError::Operation` when the graph is uninitialized, its
/// lock is poisoned, metadata serialization fails, the audio file cannot be
/// read, or graph computation/output retrieval fails.
async fn translate_audio(request: TranslationRequest) -> Result<TranslationObject, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "processing audio translation request");

    // The whisper graph must have been initialized at startup.
    let graph = match AUDIO_GRAPH.get() {
        Some(graph) => graph,
        None => {
            let err_msg = "The AUDIO_GRAPH is not initialized.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg.to_owned()));
        }
    };

    let mut graph = match graph.lock() {
        Ok(graph) => graph,
        Err(e) => {
            let err_msg = format!("Failed to lock the graph. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    // Sync per-request options into a working copy of the metadata; only push
    // the serialized metadata to the model when something actually changed.
    {
        let mut should_update = false;
        let mut metadata = graph.metadata.clone();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "current metadata: {:?}", &metadata);

        #[cfg(feature = "logging")]
        info!(target: "stdout", "Check model metadata.");

        // Translation must run with translation enabled.
        if !metadata.translate {
            metadata.translate = true;
            should_update = true;
        }

        // Copy a request option into the metadata when it is present and
        // differs from the current value.
        macro_rules! sync_option {
            ($field:ident) => {
                if let Some(value) = &request.$field {
                    if *value != metadata.$field {
                        metadata.$field = value.clone();
                        should_update = true;
                    }
                }
            };
        }

        sync_option!(language);
        sync_option!(detect_language);
        sync_option!(offset_time);
        sync_option!(duration);
        sync_option!(max_context);
        sync_option!(max_len);
        sync_option!(temperature);
        sync_option!(split_on_word);

        // Only a non-empty prompt that differs from the current one counts.
        if let Some(prompt) = &request.prompt {
            if !prompt.is_empty() && metadata.prompt.as_deref() != Some(prompt.as_str()) {
                metadata.prompt = Some(prompt.clone());
                should_update = true;
            }
        }

        if should_update {
            #[cfg(feature = "logging")]
            info!(target: "stdout", "Set the metadata to the model.");

            // Serialize once and reuse the result for both logging and the
            // tensor write; the old code serialized twice and could panic via
            // `unwrap()` before the error branch ever ran.
            match serde_json::to_string(&metadata) {
                Ok(config) => {
                    #[cfg(feature = "logging")]
                    debug!(target: "stdout", "new metadata: {}", &config);

                    // Tensor slot 1 carries the serialized metadata.
                    set_tensor_data(&mut graph, 1, config.as_bytes(), [1])?;
                }
                Err(e) => {
                    let err_msg = format!("Fail to serialize metadata to a JSON string. {}", e);

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    return Err(LlamaCoreError::Operation(err_msg));
                }
            };
        }
    }

    // Uploaded files are stored under `archives/<file-id>/<filename>`.
    let path = Path::new("archives")
        .join(&request.file.id)
        .join(&request.file.filename);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "audio file path: {:?}", &path);

    let wav_buf = load_audio_waveform(path)?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "read input tensor, size in bytes: {}", wav_buf.len());

    #[cfg(feature = "logging")]
    info!(target: "stdout", "feed the audio data to the model.");
    // Tensor slot 0 carries the raw audio bytes.
    set_tensor_data(&mut graph, 0, &wav_buf, [1, wav_buf.len()])?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "translate audio to text.");
    if let Err(e) = graph.compute() {
        let err_msg = format!("Failed to compute the graph. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        return Err(LlamaCoreError::Operation(err_msg));
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "retrieve the translation data.");

    let mut output_buffer = vec![0u8; MAX_BUFFER_SIZE];
    let output_size = graph.get_output(0, &mut output_buffer).map_err(|e| {
        let err_msg = format!("Failed to get the output tensor. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "output buffer size: {}", output_size);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "decode the translation data to plain text.");

    // Only the first `output_size` bytes of the buffer are valid output.
    let text = String::from_utf8_lossy(&output_buffer[..output_size]);

    #[cfg(feature = "logging")]
    info!(target: "stdout", "raw translation text:\n{}", &text);

    let obj = TranslationObject {
        text: text.trim().to_owned(),
    };

    // NOTE: the old trailing "Reset the model metadata." log was misleading —
    // the reset actually happens in the caller, which logs the same message.
    #[cfg(feature = "logging")]
    info!(target: "stdout", "End of the audio translation.");

    Ok(obj)
}
595
596pub async fn create_speech(request: SpeechRequest) -> Result<Vec<u8>, LlamaCoreError> {
598 #[cfg(feature = "logging")]
599 info!(target: "stdout", "processing audio speech request");
600
601 #[cfg(feature = "logging")]
602 info!(target: "stdout", "Get the model instance.");
603 let graph = match PIPER_GRAPH.get() {
604 Some(graph) => graph,
605 None => {
606 let err_msg = "The PIPER_GRAPH is not initialized.";
607
608 #[cfg(feature = "logging")]
609 error!(target: "stdout", "{}", &err_msg);
610
611 return Err(LlamaCoreError::Operation(err_msg.to_owned()));
612 }
613 };
614
615 let mut graph = match graph.lock() {
616 Ok(graph) => graph,
617 Err(e) => {
618 let err_msg = format!("Failed to lock the graph. {}", e);
619
620 #[cfg(feature = "logging")]
621 error!(target: "stdout", "{}", &err_msg);
622
623 return Err(LlamaCoreError::Operation(err_msg));
624 }
625 };
626
627 #[cfg(feature = "logging")]
629 info!(target: "stdout", "Feed the text to the model.");
630 set_tensor_data(&mut graph, 0, request.input.as_bytes(), [1])?;
631
632 #[cfg(feature = "logging")]
634 info!(target: "stdout", "create audio.");
635 if let Err(e) = graph.compute() {
636 let err_msg = format!("Failed to compute the graph. {}", e);
637
638 #[cfg(feature = "logging")]
639 error!(target: "stdout", "{}", &err_msg);
640
641 return Err(LlamaCoreError::Operation(err_msg));
642 }
643
644 #[cfg(feature = "logging")]
646 info!(target: "stdout", "[INFO] Retrieve the audio.");
647
648 let mut output_buffer = vec![0u8; MAX_BUFFER_SIZE];
649 let output_size = graph.get_output(0, &mut output_buffer).map_err(|e| {
650 let err_msg = format!("Failed to get the output tensor. {}", e);
651
652 #[cfg(feature = "logging")]
653 error!(target: "stdout", "{}", &err_msg);
654
655 LlamaCoreError::Operation(err_msg)
656 })?;
657
658 #[cfg(feature = "logging")]
659 info!(target: "stdout", "Output buffer size: {}", output_size);
660
661 Ok(output_buffer)
662}
663
#[cfg(feature = "whisper")]
/// Restore the model metadata to the copy stored on the audio graph.
///
/// Request handlers mutate only a clone of `graph.metadata`, so the copy held
/// on the graph still contains the load-time settings; writing it back undoes
/// any per-request overrides previously pushed to the model.
fn reset_model_metadata() -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    debug!(target: "stdout", "Get the original metadata.");

    // Read back the pristine metadata kept on the graph.
    let metadata = get_model_metadata()?;

    #[cfg(feature = "logging")]
    debug!(target: "stdout", "Set the original metadata to the model.");

    #[cfg(feature = "logging")]
    debug!(target: "stdout", "original metadata: {}", serde_json::to_string(&metadata).unwrap());

    // Push the original metadata into the model's config tensor (slot 1).
    update_model_metadata(&metadata)
}
681
#[cfg(feature = "whisper")]
/// Return a clone of the metadata currently stored on the audio graph.
///
/// # Errors
///
/// Returns `LlamaCoreError::Operation` when `AUDIO_GRAPH` is uninitialized or
/// its lock is poisoned.
fn get_model_metadata() -> Result<WhisperMetadata, LlamaCoreError> {
    // Fetch the global graph handle, turning absence into an operation error.
    let graph = AUDIO_GRAPH.get().ok_or_else(|| {
        let err_msg = "Fail to get the underlying value of `AUDIO_GRAPH`.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        LlamaCoreError::Operation(err_msg.into())
    })?;

    // Acquire the lock; a poisoned mutex is reported as an operation error.
    let locked = graph.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `AUDIO_GRAPH`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    Ok(locked.metadata.clone())
}
708
#[cfg(feature = "whisper")]
/// Serialize `metadata` to JSON and write it into tensor slot 1 of the audio
/// graph, making it the model's active configuration.
///
/// # Errors
///
/// Returns `LlamaCoreError::Operation` when serialization fails, `AUDIO_GRAPH`
/// is uninitialized, or its lock is poisoned.
fn update_model_metadata(metadata: &WhisperMetadata) -> Result<(), LlamaCoreError> {
    // Serialize first so we fail fast before touching the graph.
    let config = serde_json::to_string(metadata).map_err(|e| {
        let err_msg = format!("Fail to serialize metadata to a JSON string. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    // Fetch the global graph handle, turning absence into an operation error.
    let graph = AUDIO_GRAPH.get().ok_or_else(|| {
        let err_msg = "Fail to get the underlying value of `AUDIO_GRAPH`.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        LlamaCoreError::Operation(err_msg.into())
    })?;

    // Acquire the lock; a poisoned mutex is reported as an operation error.
    let mut locked = graph.lock().map_err(|e| {
        let err_msg = format!("Fail to acquire the lock of `AUDIO_GRAPH`. Reason: {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    // Tensor slot 1 carries the serialized metadata bytes.
    set_tensor_data::<u8, WhisperMetadata>(&mut locked, 1, config.as_bytes(), [1])
}