llama_core/metadata/whisper.rs

//! Define metadata for the whisper model.

use super::BaseMetadata;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};

/// The sample rate of the audio input.
pub const WHISPER_SAMPLE_RATE: usize = 16000;

/// Builder for creating whisper metadata.
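///
/// A minimal usage sketch. The model name, alias, and model path below are
/// placeholders, and the import path is assumed from this file's location:
///
/// ```ignore
/// use llama_core::metadata::whisper::WhisperMetadataBuilder;
///
/// let metadata = WhisperMetadataBuilder::new("whisper-base", "whisper")
///     .with_model_path("models/ggml-base.bin")
///     .with_language("en".to_string())
///     .with_threads(4)
///     .enable_translate(false)
///     .build();
///
/// assert_eq!(metadata.threads, 4);
/// assert_eq!(metadata.language, "en");
/// ```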
#[derive(Debug)]
pub struct WhisperMetadataBuilder {
    metadata: WhisperMetadata,
}
impl WhisperMetadataBuilder {
    /// Creates a new builder with the given model name and model alias.
    pub fn new<S: Into<String>>(model_name: S, model_alias: S) -> Self {
        let metadata = WhisperMetadata {
            model_name: model_name.into(),
            model_alias: model_alias.into(),
            ..Default::default()
        };

        Self { metadata }
    }

    /// Sets the path to the model file.
    pub fn with_model_path(mut self, model_path: impl AsRef<Path>) -> Self {
        self.metadata.model_path = model_path.as_ref().to_path_buf();
        self
    }

    /// Enables or disables plugin logging.
    pub fn enable_plugin_log(mut self, enable: bool) -> Self {
        self.metadata.log_enable = enable;
        self
    }

    /// Enables or disables debug logging.
    pub fn enable_debug_log(mut self, enable: bool) -> Self {
        self.metadata.debug_log = enable;
        self
    }

    /// Sets the number of threads to use during computation.
    pub fn with_threads(mut self, threads: u64) -> Self {
        self.metadata.threads = threads;
        self
    }

    /// Enables or disables translation from the source language to English.
    pub fn enable_translate(mut self, enable: bool) -> Self {
        self.metadata.translate = enable;
        self
    }

    /// Sets the language of the input audio. Use `auto` for auto-detection.
    pub fn with_language(mut self, language: String) -> Self {
        self.metadata.language = language;
        self
    }

    /// Sets the number of processors to use during computation.
    pub fn with_processors(mut self, processors: u64) -> Self {
        self.metadata.processors = processors;
        self
    }

    /// Sets the time offset in milliseconds.
    pub fn with_offset_time(mut self, offset_t: u64) -> Self {
        self.metadata.offset_time = offset_t;
        self
    }

    /// Sets the duration of audio to process in milliseconds.
    pub fn with_duration(mut self, duration: u64) -> Self {
        self.metadata.duration = duration;
        self
    }

    /// Sets the maximum number of text context tokens to store.
    pub fn with_max_context(mut self, max_context: i32) -> Self {
        self.metadata.max_context = max_context;
        self
    }

    /// Sets the maximum segment length in characters.
    pub fn with_max_len(mut self, max_len: u64) -> Self {
        self.metadata.max_len = max_len;
        self
    }

    /// Sets whether to split on word rather than on token.
    pub fn split_on_word(mut self, split_on_word: bool) -> Self {
        self.metadata.split_on_word = split_on_word;
        self
    }

    /// Sets whether to output the result to a text file.
    pub fn output_txt(mut self, output_txt: bool) -> Self {
        self.metadata.output_txt = output_txt;
        self
    }

    /// Sets whether to output the result to a vtt file.
    pub fn output_vtt(mut self, output_vtt: bool) -> Self {
        self.metadata.output_vtt = output_vtt;
        self
    }

    /// Sets whether to output the result to a srt file.
    pub fn output_srt(mut self, output_srt: bool) -> Self {
        self.metadata.output_srt = output_srt;
        self
    }

    /// Sets whether to output the result to a lrc file.
    pub fn output_lrc(mut self, output_lrc: bool) -> Self {
        self.metadata.output_lrc = output_lrc;
        self
    }

    /// Sets whether to output the result to a CSV file.
    pub fn output_csv(mut self, output_csv: bool) -> Self {
        self.metadata.output_csv = output_csv;
        self
    }

    /// Sets whether to output the result to a JSON file.
    pub fn output_json(mut self, output_json: bool) -> Self {
        self.metadata.output_json = output_json;
        self
    }

    /// Sets the sampling temperature, between 0 and 1.
    pub fn with_temperature(mut self, temperature: f64) -> Self {
        self.metadata.temperature = temperature;
        self
    }

    /// Sets whether to automatically detect the spoken language in the input audio.
    pub fn detect_language(mut self, detect_language: bool) -> Self {
        self.metadata.detect_language = detect_language;
        self
    }

    /// Sets the text used to guide the model. Ignored if empty.
    pub fn with_prompt(mut self, prompt: String) -> Self {
        if !prompt.is_empty() {
            self.metadata.prompt = Some(prompt);
        }
        self
    }

    /// Consumes the builder and returns the metadata.
    pub fn build(self) -> WhisperMetadata {
        self.metadata
    }
}

/// Metadata for the whisper model.
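///
/// A sketch of how this serializes for the backend plugin, using the serde
/// renames declared on the fields below (assumes `serde_json` is available;
/// illustrative only):
///
/// ```ignore
/// let metadata = WhisperMetadata::default();
/// let json = serde_json::to_string(&metadata).unwrap();
///
/// // Renamed fields use the plugin-facing keys, e.g. "enable-log", "offset-t",
/// // "max-context", "max-len", "split-on-word", and "detect-language".
/// assert!(json.contains("\"offset-t\":0"));
/// // Fields marked `skip_serializing` (model_name, model_alias, model_path) are omitted.
/// assert!(!json.contains("model_path"));
/// ```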
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct WhisperMetadata {
    // This field is not defined for the backend plugin.
    #[serde(skip_serializing)]
    pub model_name: String,
    // This field is not defined for the backend plugin.
    #[serde(skip_serializing)]
    pub model_alias: String,
    // Path to the model file.
    #[serde(skip_serializing)]
    pub model_path: PathBuf,

    /// Enable logging of the plugin. Defaults to false.
    #[serde(rename = "enable-log")]
    pub log_enable: bool,
    /// Enable debug mode. Defaults to false.
    #[serde(rename = "enable-debug-log")]
    pub debug_log: bool,

    /// Number of threads to use during computation. Defaults to 4.
    pub threads: u64,
    /// Translate from source language to English. Defaults to false.
    pub translate: bool,
    /// The language of the input audio. `auto` for auto-detection. Defaults to `en`.
    ///
    /// Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.
    pub language: String,
    /// Number of processors to use during computation. Defaults to 1.
    pub processors: u64,
    /// Time offset in milliseconds. Defaults to 0.
    #[serde(rename = "offset-t")]
    pub offset_time: u64,
    /// Duration of audio to process in milliseconds. Defaults to 0.
    pub duration: u64,
    /// Maximum number of text context tokens to store. Defaults to -1.
    #[serde(rename = "max-context")]
    pub max_context: i32,
    /// Maximum segment length in characters. Defaults to 0.
    #[serde(rename = "max-len")]
    pub max_len: u64,
    /// Split on word rather than on token. Defaults to false.
    #[serde(rename = "split-on-word")]
    pub split_on_word: bool,
    /// Output result in a text file. Defaults to false.
    pub output_txt: bool,
    /// Output result in a vtt file. Defaults to false.
    pub output_vtt: bool,
    /// Output result in a srt file. Defaults to false.
    pub output_srt: bool,
    /// Output result in a lrc file. Defaults to false.
    pub output_lrc: bool,
    /// Output result in a CSV file. Defaults to false.
    pub output_csv: bool,
    /// Output result in a JSON file. Defaults to false.
    pub output_json: bool,
    /// Sampling temperature, between 0 and 1. Defaults to 0.0.
    pub temperature: f64,
    /// Automatically detect the spoken language in the provided audio input. Defaults to false.
    #[serde(rename = "detect-language")]
    pub detect_language: bool,
    /// Text to guide the model. The max length is n_text_ctx/2 tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
}
impl Default for WhisperMetadata {
    fn default() -> Self {
        Self {
            model_name: String::new(),
            model_alias: String::new(),
            model_path: PathBuf::new(),
            log_enable: false,
            debug_log: false,
            threads: 4,
            translate: false,
            language: "en".to_string(),
            processors: 1,
            offset_time: 0,
            duration: 0,
            max_context: -1,
            max_len: 0,
            split_on_word: false,
            output_txt: false,
            output_vtt: false,
            output_srt: false,
            output_lrc: false,
            output_csv: false,
            output_json: false,
            temperature: 0.0,
            detect_language: false,
            prompt: None,
        }
    }
}
impl BaseMetadata for WhisperMetadata {
    fn model_name(&self) -> &str {
        &self.model_name
    }

    fn model_alias(&self) -> &str {
        &self.model_alias
    }
}