1use crate::{embeddings::embeddings, error::LlamaCoreError, running_mode, RunningMode};
4use endpoints::{
5 embeddings::{EmbeddingObject, EmbeddingRequest, EmbeddingsResponse, InputText},
6 rag::{RagScoredPoint, RetrieveObject},
7};
8use qdrant::*;
9use serde_json::Value;
10use std::collections::HashSet;
11
12pub async fn rag_doc_chunks_to_embeddings(
22 embedding_request: &EmbeddingRequest,
23) -> Result<EmbeddingsResponse, LlamaCoreError> {
24 #[cfg(feature = "logging")]
25 info!(target: "stdout", "Convert document chunks to embeddings.");
26
27 let running_mode = running_mode()?;
28 if running_mode != RunningMode::RAG {
29 let err_msg = format!(
30 "Creating knowledge base is not supported in the {} mode.",
31 running_mode
32 );
33
34 #[cfg(feature = "logging")]
35 error!(target: "stdout", "{}", &err_msg);
36
37 return Err(LlamaCoreError::Operation(err_msg));
38 }
39
40 let qdrant_url = match embedding_request.vdb_server_url.as_deref() {
41 Some(url) => url.to_string(),
42 None => {
43 let err_msg = "The VectorDB server URL is not provided.";
44
45 #[cfg(feature = "logging")]
46 error!(target: "stdout", "{}", &err_msg);
47
48 return Err(LlamaCoreError::Operation(err_msg.into()));
49 }
50 };
51 let qdrant_collection_name = match embedding_request.vdb_collection_name.as_deref() {
52 Some(name) => name.to_string(),
53 None => {
54 let err_msg = "The VectorDB collection name is not provided.";
55
56 #[cfg(feature = "logging")]
57 error!(target: "stdout", "{}", &err_msg);
58
59 return Err(LlamaCoreError::Operation(err_msg.into()));
60 }
61 };
62
63 #[cfg(feature = "logging")]
64 info!(target: "stdout", "Compute embeddings for document chunks.");
65
66 #[cfg(feature = "logging")]
67 if let Ok(request_str) = serde_json::to_string(&embedding_request) {
68 debug!(target: "stdout", "Embedding request: {}", request_str);
69 }
70
71 let embeddings_response = embeddings(embedding_request).await?;
73 let embeddings = embeddings_response.data.as_slice();
74 let dim = embeddings[0].embedding.len();
75
76 let mut qdrant_client = qdrant::Qdrant::new_with_url(qdrant_url);
78
79 if let Some(key) = embedding_request.vdb_api_key.as_deref() {
81 if !key.is_empty() {
82 #[cfg(feature = "logging")]
83 debug!(target: "stdout", "Set the API key for the VectorDB server.");
84
85 qdrant_client.set_api_key(key);
86 }
87 }
88
89 qdrant_create_collection(&qdrant_client, &qdrant_collection_name, dim).await?;
91
92 let chunks = match &embedding_request.input {
93 InputText::String(text) => vec![text.clone()],
94 InputText::ArrayOfStrings(texts) => texts.clone(),
95 InputText::ArrayOfTokens(tokens) => tokens.iter().map(|t| t.to_string()).collect(),
96 InputText::ArrayOfTokenArrays(token_arrays) => token_arrays
97 .iter()
98 .map(|tokens| tokens.iter().map(|t| t.to_string()).collect())
99 .collect(),
100 };
101
102 qdrant_persist_embeddings(
104 &qdrant_client,
105 &qdrant_collection_name,
106 embeddings,
107 chunks.as_slice(),
108 )
109 .await?;
110
111 Ok(embeddings_response)
112}
113
114pub async fn rag_query_to_embeddings(
120 embedding_request: &EmbeddingRequest,
121) -> Result<EmbeddingsResponse, LlamaCoreError> {
122 #[cfg(feature = "logging")]
123 info!(target: "stdout", "Compute embeddings for the user query.");
124
125 let running_mode = running_mode()?;
126 if running_mode != RunningMode::RAG {
127 let err_msg = format!("The RAG query is not supported in the {running_mode} mode.",);
128
129 #[cfg(feature = "logging")]
130 error!(target: "stdout", "{}", &err_msg);
131
132 return Err(LlamaCoreError::Operation(err_msg));
133 }
134
135 embeddings(embedding_request).await
136}
137
138pub async fn rag_retrieve_context(
152 query_embedding: &[f32],
153 vdb_server_url: impl AsRef<str>,
154 vdb_collection_name: impl AsRef<str>,
155 limit: usize,
156 score_threshold: Option<f32>,
157 vdb_api_key: Option<String>,
158) -> Result<RetrieveObject, LlamaCoreError> {
159 #[cfg(feature = "logging")]
160 {
161 info!(target: "stdout", "Retrieve context.");
162
163 info!(target: "stdout", "qdrant_url: {}, qdrant_collection_name: {}, limit: {}, score_threshold: {}", vdb_server_url.as_ref(), vdb_collection_name.as_ref(), limit, score_threshold.unwrap_or_default());
164 }
165
166 let running_mode = running_mode()?;
167 if running_mode != RunningMode::RAG {
168 let err_msg = format!(
169 "The context retrieval is not supported in the {} mode.",
170 running_mode
171 );
172
173 #[cfg(feature = "logging")]
174 error!(target: "stdout", "{}", &err_msg);
175
176 return Err(LlamaCoreError::Operation(err_msg));
177 }
178
179 let mut qdrant_client = qdrant::Qdrant::new_with_url(vdb_server_url.as_ref().to_string());
181
182 if let Some(key) = vdb_api_key.as_deref() {
184 if !key.is_empty() {
185 #[cfg(feature = "logging")]
186 debug!(target: "stdout", "Set the API key for the VectorDB server.");
187
188 qdrant_client.set_api_key(key);
189 }
190 }
191
192 let scored_points = match qdrant_search_similar_points(
194 &qdrant_client,
195 vdb_collection_name.as_ref(),
196 query_embedding,
197 limit,
198 score_threshold,
199 )
200 .await
201 {
202 Ok(points) => points,
203 Err(e) => {
204 #[cfg(feature = "logging")]
205 error!(target: "stdout", "{}", e.to_string());
206
207 return Err(e);
208 }
209 };
210
211 #[cfg(feature = "logging")]
212 info!(target: "stdout", "remove duplicates from {} scored points", scored_points.len());
213
214 let mut seen = HashSet::new();
216 let unique_scored_points: Vec<ScoredPoint> = scored_points
217 .into_iter()
218 .filter(|point| {
219 seen.insert(
220 point
221 .payload
222 .as_ref()
223 .unwrap()
224 .get("source")
225 .unwrap()
226 .to_string(),
227 )
228 })
229 .collect();
230
231 #[cfg(feature = "logging")]
232 info!(target: "stdout", "number of unique scored points: {}", unique_scored_points.len());
233
234 let ro = match unique_scored_points.is_empty() {
235 true => RetrieveObject {
236 points: None,
237 limit,
238 score_threshold: score_threshold.unwrap_or(0.0),
239 },
240 false => {
241 let mut points: Vec<RagScoredPoint> = vec![];
242 for point in unique_scored_points.iter() {
243 if let Some(payload) = &point.payload {
244 if let Some(source) = payload.get("source").and_then(Value::as_str) {
245 points.push(RagScoredPoint {
246 source: source.to_string(),
247 score: point.score,
248 })
249 }
250
251 #[cfg(feature = "logging")]
253 if let Some(search) = payload.get("search").and_then(Value::as_str) {
254 info!(target: "stdout", "search: {}", search);
255 }
256 }
257 }
258
259 RetrieveObject {
260 points: Some(points),
261 limit,
262 score_threshold: score_threshold.unwrap_or(0.0),
263 }
264 }
265 };
266
267 Ok(ro)
268}
269
270async fn qdrant_create_collection(
271 qdrant_client: &qdrant::Qdrant,
272 collection_name: impl AsRef<str>,
273 dim: usize,
274) -> Result<(), LlamaCoreError> {
275 #[cfg(feature = "logging")]
276 info!(target: "stdout", "Create a Qdrant collection named {} of {} dimensions.", collection_name.as_ref(), dim);
277
278 if let Err(e) = qdrant_client
279 .create_collection(collection_name.as_ref(), dim as u32)
280 .await
281 {
282 let err_msg = e.to_string();
283
284 #[cfg(feature = "logging")]
285 error!(target: "stdout", "{}", &err_msg);
286
287 return Err(LlamaCoreError::Qdrant(err_msg));
288 }
289
290 Ok(())
291}
292
293async fn qdrant_persist_embeddings(
294 qdrant_client: &qdrant::Qdrant,
295 collection_name: impl AsRef<str>,
296 embeddings: &[EmbeddingObject],
297 chunks: &[String],
298) -> Result<(), LlamaCoreError> {
299 #[cfg(feature = "logging")]
300 info!(target: "stdout", "Persist embeddings to the Qdrant instance.");
301
302 let mut points = Vec::<Point>::new();
303 for embedding in embeddings {
304 let vector: Vec<_> = embedding.embedding.iter().map(|x| *x as f32).collect();
306
307 let payload = serde_json::json!({"source": chunks[embedding.index as usize]})
309 .as_object()
310 .map(|m| m.to_owned());
311
312 let p = Point {
314 id: PointId::Num(embedding.index),
315 vector,
316 payload,
317 };
318
319 points.push(p);
320 }
321
322 #[cfg(feature = "logging")]
323 info!(target: "stdout", "Number of points to be upserted: {}", points.len());
324
325 if let Err(e) = qdrant_client
326 .upsert_points(collection_name.as_ref(), points)
327 .await
328 {
329 let err_msg = format!("{}", e);
330
331 #[cfg(feature = "logging")]
332 error!(target: "stdout", "{}", &err_msg);
333
334 return Err(LlamaCoreError::Qdrant(err_msg));
335 }
336
337 Ok(())
338}
339
340async fn qdrant_search_similar_points(
341 qdrant_client: &qdrant::Qdrant,
342 collection_name: impl AsRef<str>,
343 query_vector: &[f32],
344 limit: usize,
345 score_threshold: Option<f32>,
346) -> Result<Vec<ScoredPoint>, LlamaCoreError> {
347 #[cfg(feature = "logging")]
348 info!(target: "stdout", "Search similar points from the qdrant instance.");
349
350 match qdrant_client
351 .search_points(
352 collection_name.as_ref(),
353 query_vector.to_vec(),
354 limit as u64,
355 score_threshold,
356 )
357 .await
358 {
359 Ok(search_result) => {
360 #[cfg(feature = "logging")]
361 info!(target: "stdout", "Number of similar points found: {}", search_result.len());
362
363 Ok(search_result)
364 }
365 Err(e) => {
366 let err_msg = e.to_string();
367
368 #[cfg(feature = "logging")]
369 error!(target: "stdout", "{}", &err_msg);
370
371 Err(LlamaCoreError::Qdrant(err_msg))
372 }
373 }
374}