llama_core/tts.rs

use crate::{
    error::LlamaCoreError,
    metadata::ggml::GgmlTtsMetadata,
    running_mode,
    utils::{set_tensor_data, set_tensor_data_u8},
    Graph, RunningMode, MAX_BUFFER_SIZE, TTS_GRAPHS,
};
use endpoints::audio::speech::SpeechRequest;

/// Generate audio from the input text.
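///
/// A minimal usage sketch (not compiled as a doctest). It assumes a TTS model
/// has already been registered in `TTS_GRAPHS` and that `SpeechRequest`
/// provides a `Default` impl for the fields not shown; only `model` and
/// `input` are read directly by this module.
///
/// ```ignore
/// use endpoints::audio::speech::SpeechRequest;
///
/// let request = SpeechRequest {
///     model: "my-tts-model".to_string(),
///     input: "Hello, world!".to_string(),
///     // hypothetical: assumes the remaining fields have defaults
///     ..Default::default()
/// };
/// let audio_bytes = create_speech(request).await?;
/// ```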
pub async fn create_speech(request: SpeechRequest) -> Result<Vec<u8>, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "processing audio speech request");

    let running_mode = running_mode()?;
    if !running_mode.contains(RunningMode::TTS) {
        let err_msg = "Generating audio speech is only supported in the TTS mode.";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::Operation(err_msg.into()));
    }

    let model_name = &request.model;

    let res = {
        let tts_graphs = match TTS_GRAPHS.get() {
            Some(tts_graphs) => tts_graphs,
            None => {
                let err_msg = "Failed to get the underlying value of `TTS_GRAPHS`.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                return Err(LlamaCoreError::Operation(err_msg.into()));
            }
        };

        let mut tts_graphs = tts_graphs.lock().map_err(|e| {
            let err_msg = format!("Failed to acquire the lock of `TTS_GRAPHS`. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::Operation(err_msg)
        })?;

        match tts_graphs.contains_key(model_name) {
            true => {
                let graph = tts_graphs.get_mut(model_name).unwrap();

                compute_by_graph(graph, &request)
            }
            false => match tts_graphs.iter_mut().next() {
                Some((_name, graph)) => compute_by_graph(graph, &request),
                None => {
                    let err_msg = "There is no model available in the tts graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Reset the model metadata");

    // reset the model metadata
    reset_model_metadata(Some(model_name))?;

    res
}

/// Generate audio from the input text by computing the given TTS graph.
fn compute_by_graph(
    graph: &mut Graph<GgmlTtsMetadata>,
    request: &SpeechRequest,
) -> Result<Vec<u8>, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Input text: {}", &request.input);

    // set the input tensor
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Feed the text to the model.");
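    // Tensor index 0 carries the raw UTF-8 bytes of the input text; the `[1]`
    // dimension treats the text as a single flat buffer (an assumption about
    // how the underlying TTS backend interprets its text input).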
    set_tensor_data(graph, 0, request.input.as_bytes(), [1])?;

    // compute the graph
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Generate audio");
    if let Err(e) = graph.compute() {
        let err_msg = format!("Failed to compute the graph. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        return Err(LlamaCoreError::Operation(err_msg));
    }

    // get the output tensor
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Retrieve the audio");

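    // The backend writes the generated audio into a caller-provided buffer;
    // MAX_BUFFER_SIZE is an upper bound, and `get_output` reports how many
    // bytes were actually written.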
    let mut output_buffer = vec![0u8; MAX_BUFFER_SIZE];
    let output_size = graph.get_output(0, &mut output_buffer).map_err(|e| {
        let err_msg = format!("Failed to get the output tensor. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Output buffer size: {}", output_size);

    // drop the unused tail of the buffer so only the generated audio is returned
    output_buffer.truncate(output_size);

    Ok(output_buffer)
}

/// Get a copy of the metadata of the model.
///
/// Falls back to the first available graph when `model_name` is `None` or
/// does not match any registered TTS graph.
fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlTtsMetadata, LlamaCoreError> {
    let tts_graphs = match TTS_GRAPHS.get() {
        Some(tts_graphs) => tts_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `TTS_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let tts_graphs = tts_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `TTS_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    match model_name {
        Some(model_name) => match tts_graphs.contains_key(model_name) {
            true => {
                let graph = tts_graphs.get(model_name).unwrap();
                Ok(graph.metadata.clone())
            }
            false => match tts_graphs.iter().next() {
                Some((_, graph)) => Ok(graph.metadata.clone()),
                None => {
                    let err_msg = "There is no model available in the tts graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", &err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            },
        },
        None => match tts_graphs.iter().next() {
            Some((_, graph)) => Ok(graph.metadata.clone()),
            None => {
                let err_msg = "There is no model available in the tts graphs.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                Err(LlamaCoreError::Operation(err_msg.into()))
            }
        },
    }
}

/// Update the metadata of the model.
fn update_model_metadata(
    model_name: Option<&String>,
    metadata: &GgmlTtsMetadata,
) -> Result<(), LlamaCoreError> {
    let config = match serde_json::to_string(metadata) {
        Ok(config) => config,
        Err(e) => {
            let err_msg = format!("Failed to serialize metadata to a JSON string. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            return Err(LlamaCoreError::Operation(err_msg));
        }
    };

    let tts_graphs = match TTS_GRAPHS.get() {
        Some(tts_graphs) => tts_graphs,
        None => {
            let err_msg = "Failed to get the underlying value of `TTS_GRAPHS`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    let mut tts_graphs = tts_graphs.lock().map_err(|e| {
        let err_msg = format!("Failed to acquire the lock of `TTS_GRAPHS`. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

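    // Tensor index 1 is treated as the metadata/config tensor: writing the
    // serialized JSON there reconfigures the graph (an assumption based on how
    // `set_tensor_data_u8` is used throughout this crate).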
    match model_name {
        Some(model_name) => {
            match tts_graphs.contains_key(model_name) {
                true => {
                    let graph = tts_graphs.get_mut(model_name).unwrap();
                    // update metadata
                    set_tensor_data_u8(graph, 1, config.as_bytes())
                }
                false => match tts_graphs.iter_mut().next() {
                    Some((_, graph)) => {
                        // update metadata
                        set_tensor_data_u8(graph, 1, config.as_bytes())
                    }
                    None => {
                        let err_msg = "There is no model available in the tts graphs.";

                        #[cfg(feature = "logging")]
                        error!(target: "stdout", "{}", &err_msg);

                        Err(LlamaCoreError::Operation(err_msg.into()))
                    }
                },
            }
        }
        None => {
            match tts_graphs.iter_mut().next() {
                Some((_, graph)) => {
                    // update metadata
                    set_tensor_data_u8(graph, 1, config.as_bytes())
                }
                None => {
                    let err_msg = "There is no model available in the tts graphs.";

                    #[cfg(feature = "logging")]
                    error!(target: "stdout", "{}", err_msg);

                    Err(LlamaCoreError::Operation(err_msg.into()))
                }
            }
        }
    }
}

/// Reset the model metadata by re-applying the metadata stored on the graph.
fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
    // get a copy of the metadata
    let metadata = get_model_metadata(model_name)?;

    // update the model with the original metadata
    update_model_metadata(model_name, &metadata)
}