1use crate::{
2 error::LlamaCoreError,
3 metadata::ggml::GgmlTtsMetadata,
4 running_mode,
5 utils::{set_tensor_data, set_tensor_data_u8},
6 Graph, RunningMode, MAX_BUFFER_SIZE, TTS_GRAPHS,
7};
8use endpoints::audio::speech::SpeechRequest;
9
10pub async fn create_speech(request: SpeechRequest) -> Result<Vec<u8>, LlamaCoreError> {
12 #[cfg(feature = "logging")]
13 info!(target: "stdout", "processing audio speech request");
14
15 let running_mode = running_mode()?;
16 if !running_mode.contains(RunningMode::TTS) {
17 let err_msg = "Generating audio speech is only supported in the tts mode.";
18
19 #[cfg(feature = "logging")]
20 error!(target: "stdout", "{}", err_msg);
21
22 return Err(LlamaCoreError::Operation(err_msg.into()));
23 }
24
25 let model_name = &request.model;
26
27 let res = {
28 let tts_graphs = match TTS_GRAPHS.get() {
29 Some(tts_graphs) => tts_graphs,
30 None => {
31 let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
32
33 #[cfg(feature = "logging")]
34 error!(target: "stdout", "{}", &err_msg);
35
36 return Err(LlamaCoreError::Operation(err_msg.into()));
37 }
38 };
39
40 let mut tts_graphs = tts_graphs.lock().map_err(|e| {
41 let err_msg = format!("Fail to acquire the lock of `TTS_GRAPHS`. {}", e);
42
43 #[cfg(feature = "logging")]
44 error!(target: "stdout", "{}", &err_msg);
45
46 LlamaCoreError::Operation(err_msg)
47 })?;
48
49 match tts_graphs.contains_key(model_name) {
50 true => {
51 let graph = tts_graphs.get_mut(model_name).unwrap();
52
53 compute_by_graph(graph, &request)
54 }
55 false => match tts_graphs.iter_mut().next() {
56 Some((_name, graph)) => compute_by_graph(graph, &request),
57 None => {
58 let err_msg = "There is no model available in the chat graphs.";
59
60 #[cfg(feature = "logging")]
61 error!(target: "stdout", "{}", &err_msg);
62
63 Err(LlamaCoreError::Operation(err_msg.into()))
64 }
65 },
66 }
67 };
68
69 #[cfg(feature = "logging")]
70 info!(target: "stdout", "Reset the model metadata");
71
72 reset_model_metadata(Some(model_name))?;
74
75 res
76}
77
78fn compute_by_graph(
79 graph: &mut Graph<GgmlTtsMetadata>,
80 request: &SpeechRequest,
81) -> Result<Vec<u8>, LlamaCoreError> {
82 #[cfg(feature = "logging")]
83 info!(target: "stdout", "Input text: {}", &request.input);
84
85 #[cfg(feature = "logging")]
87 info!(target: "stdout", "Feed the text to the model.");
88 set_tensor_data(graph, 0, request.input.as_bytes(), [1])?;
89
90 #[cfg(feature = "logging")]
92 info!(target: "stdout", "Generate audio");
93 if let Err(e) = graph.compute() {
94 let err_msg = format!("Failed to compute the graph. {}", e);
95
96 #[cfg(feature = "logging")]
97 error!(target: "stdout", "{}", &err_msg);
98
99 return Err(LlamaCoreError::Operation(err_msg));
100 }
101
102 #[cfg(feature = "logging")]
104 info!(target: "stdout", "[INFO] Retrieve the audio.");
105
106 let mut output_buffer = vec![0u8; MAX_BUFFER_SIZE];
107 let output_size = graph.get_output(0, &mut output_buffer).map_err(|e| {
108 let err_msg = format!("Failed to get the output tensor. {}", e);
109
110 #[cfg(feature = "logging")]
111 error!(target: "stdout", "{}", &err_msg);
112
113 LlamaCoreError::Operation(err_msg)
114 })?;
115
116 #[cfg(feature = "logging")]
117 info!(target: "stdout", "Output buffer size: {}", output_size);
118
119 Ok(output_buffer)
120}
121
122fn get_model_metadata(model_name: Option<&String>) -> Result<GgmlTtsMetadata, LlamaCoreError> {
124 let tts_graphs = match TTS_GRAPHS.get() {
125 Some(tts_graphs) => tts_graphs,
126 None => {
127 let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
128
129 #[cfg(feature = "logging")]
130 error!(target: "stdout", "{}", err_msg);
131
132 return Err(LlamaCoreError::Operation(err_msg.into()));
133 }
134 };
135
136 let tts_graphs = tts_graphs.lock().map_err(|e| {
137 let err_msg = format!("Fail to acquire the lock of `TTS_GRAPHS`. {}", e);
138
139 #[cfg(feature = "logging")]
140 error!(target: "stdout", "{}", &err_msg);
141
142 LlamaCoreError::Operation(err_msg)
143 })?;
144
145 match model_name {
146 Some(model_name) => match tts_graphs.contains_key(model_name) {
147 true => {
148 let graph = tts_graphs.get(model_name).unwrap();
149 Ok(graph.metadata.clone())
150 }
151 false => match tts_graphs.iter().next() {
152 Some((_, graph)) => Ok(graph.metadata.clone()),
153 None => {
154 let err_msg = "There is no model available in the tts graphs.";
155
156 #[cfg(feature = "logging")]
157 error!(target: "stdout", "{}", &err_msg);
158
159 Err(LlamaCoreError::Operation(err_msg.into()))
160 }
161 },
162 },
163 None => match tts_graphs.iter().next() {
164 Some((_, graph)) => Ok(graph.metadata.clone()),
165 None => {
166 let err_msg = "There is no model available in the tts graphs.";
167
168 #[cfg(feature = "logging")]
169 error!(target: "stdout", "{}", err_msg);
170
171 Err(LlamaCoreError::Operation(err_msg.into()))
172 }
173 },
174 }
175}
176
177fn update_model_metadata(
178 model_name: Option<&String>,
179 metadata: &GgmlTtsMetadata,
180) -> Result<(), LlamaCoreError> {
181 let config = match serde_json::to_string(metadata) {
182 Ok(config) => config,
183 Err(e) => {
184 let err_msg = format!("Fail to serialize metadata to a JSON string. {}", e);
185
186 #[cfg(feature = "logging")]
187 error!(target: "stdout", "{}", &err_msg);
188
189 return Err(LlamaCoreError::Operation(err_msg));
190 }
191 };
192
193 let tts_graphs = match TTS_GRAPHS.get() {
194 Some(tts_graphs) => tts_graphs,
195 None => {
196 let err_msg = "Fail to get the underlying value of `TTS_GRAPHS`.";
197
198 #[cfg(feature = "logging")]
199 error!(target: "stdout", "{}", err_msg);
200
201 return Err(LlamaCoreError::Operation(err_msg.into()));
202 }
203 };
204
205 let mut tts_graphs = tts_graphs.lock().map_err(|e| {
206 let err_msg = format!("Fail to acquire the lock of `TTS_GRAPHS`. Reason: {}", e);
207
208 #[cfg(feature = "logging")]
209 error!(target: "stdout", "{}", &err_msg);
210
211 LlamaCoreError::Operation(err_msg)
212 })?;
213
214 match model_name {
215 Some(model_name) => {
216 match tts_graphs.contains_key(model_name) {
217 true => {
218 let graph = tts_graphs.get_mut(model_name).unwrap();
219 set_tensor_data_u8(graph, 1, config.as_bytes())
221 }
222 false => match tts_graphs.iter_mut().next() {
223 Some((_, graph)) => {
224 set_tensor_data_u8(graph, 1, config.as_bytes())
226 }
227 None => {
228 let err_msg = "There is no model available in the tts graphs.";
229
230 #[cfg(feature = "logging")]
231 error!(target: "stdout", "{}", &err_msg);
232
233 Err(LlamaCoreError::Operation(err_msg.into()))
234 }
235 },
236 }
237 }
238 None => {
239 match tts_graphs.iter_mut().next() {
240 Some((_, graph)) => {
241 set_tensor_data_u8(graph, 1, config.as_bytes())
243 }
244 None => {
245 let err_msg = "There is no model available in the tts graphs.";
246
247 #[cfg(feature = "logging")]
248 error!(target: "stdout", "{}", err_msg);
249
250 Err(LlamaCoreError::Operation(err_msg.into()))
251 }
252 }
253 }
254 }
255}
256
257fn reset_model_metadata(model_name: Option<&String>) -> Result<(), LlamaCoreError> {
258 let metadata = get_model_metadata(model_name)?;
260
261 update_model_metadata(model_name, &metadata)
263}