1use crate::{embeddings::embeddings, error::LlamaCoreError, running_mode, RunningMode};
4use endpoints::{
5 embeddings::{EmbeddingObject, EmbeddingRequest, EmbeddingsResponse, InputText},
6 rag::{RagScoredPoint, RetrieveObject},
7};
8use qdrant::*;
9use serde_json::Value;
10use std::collections::HashSet;
11
12pub async fn rag_doc_chunks_to_embeddings(
22 embedding_request: &EmbeddingRequest,
23) -> Result<EmbeddingsResponse, LlamaCoreError> {
24 #[cfg(feature = "logging")]
25 info!(target: "stdout", "Convert document chunks to embeddings.");
26
27 let running_mode = running_mode()?;
28 if running_mode != RunningMode::RAG {
29 let err_msg =
30 format!("Creating knowledge base is not supported in the {running_mode} mode.");
31
32 #[cfg(feature = "logging")]
33 error!(target: "stdout", "{}", &err_msg);
34
35 return Err(LlamaCoreError::Operation(err_msg));
36 }
37
38 let qdrant_url = match embedding_request.vdb_server_url.as_deref() {
39 Some(url) => url.to_string(),
40 None => {
41 let err_msg = "The VectorDB server URL is not provided.";
42
43 #[cfg(feature = "logging")]
44 error!(target: "stdout", "{}", &err_msg);
45
46 return Err(LlamaCoreError::Operation(err_msg.into()));
47 }
48 };
49 let qdrant_collection_name = match embedding_request.vdb_collection_name.as_deref() {
50 Some(name) => name.to_string(),
51 None => {
52 let err_msg = "The VectorDB collection name is not provided.";
53
54 #[cfg(feature = "logging")]
55 error!(target: "stdout", "{}", &err_msg);
56
57 return Err(LlamaCoreError::Operation(err_msg.into()));
58 }
59 };
60
61 #[cfg(feature = "logging")]
62 info!(target: "stdout", "Compute embeddings for document chunks.");
63
64 #[cfg(feature = "logging")]
65 if let Ok(request_str) = serde_json::to_string(&embedding_request) {
66 debug!(target: "stdout", "Embedding request: {request_str}");
67 }
68
69 let embeddings_response = embeddings(embedding_request).await?;
71 let embeddings = embeddings_response.data.as_slice();
72 let dim = embeddings[0].embedding.len();
73
74 let mut qdrant_client = qdrant::Qdrant::new_with_url(qdrant_url);
76
77 if let Some(key) = embedding_request.vdb_api_key.as_deref() {
79 if !key.is_empty() {
80 #[cfg(feature = "logging")]
81 debug!(target: "stdout", "Set the API key for the VectorDB server.");
82
83 qdrant_client.set_api_key(key);
84 }
85 }
86
87 qdrant_create_collection(&qdrant_client, &qdrant_collection_name, dim).await?;
89
90 let chunks = match &embedding_request.input {
91 InputText::String(text) => vec![text.clone()],
92 InputText::ArrayOfStrings(texts) => texts.clone(),
93 InputText::ArrayOfTokens(tokens) => tokens.iter().map(|t| t.to_string()).collect(),
94 InputText::ArrayOfTokenArrays(token_arrays) => token_arrays
95 .iter()
96 .map(|tokens| tokens.iter().map(|t| t.to_string()).collect())
97 .collect(),
98 };
99
100 qdrant_persist_embeddings(
102 &qdrant_client,
103 &qdrant_collection_name,
104 embeddings,
105 chunks.as_slice(),
106 )
107 .await?;
108
109 Ok(embeddings_response)
110}
111
112pub async fn rag_query_to_embeddings(
118 embedding_request: &EmbeddingRequest,
119) -> Result<EmbeddingsResponse, LlamaCoreError> {
120 #[cfg(feature = "logging")]
121 info!(target: "stdout", "Compute embeddings for the user query.");
122
123 let running_mode = running_mode()?;
124 if running_mode != RunningMode::RAG {
125 let err_msg = format!("The RAG query is not supported in the {running_mode} mode.",);
126
127 #[cfg(feature = "logging")]
128 error!(target: "stdout", "{}", &err_msg);
129
130 return Err(LlamaCoreError::Operation(err_msg));
131 }
132
133 embeddings(embedding_request).await
134}
135
136pub async fn rag_retrieve_context(
150 query_embedding: &[f32],
151 vdb_server_url: impl AsRef<str>,
152 vdb_collection_name: impl AsRef<str>,
153 limit: usize,
154 score_threshold: Option<f32>,
155 vdb_api_key: Option<String>,
156) -> Result<RetrieveObject, LlamaCoreError> {
157 #[cfg(feature = "logging")]
158 {
159 info!(target: "stdout", "Retrieve context.");
160
161 info!(target: "stdout", "qdrant_url: {}, qdrant_collection_name: {}, limit: {}, score_threshold: {}", vdb_server_url.as_ref(), vdb_collection_name.as_ref(), limit, score_threshold.unwrap_or_default());
162 }
163
164 let running_mode = running_mode()?;
165 if running_mode != RunningMode::RAG {
166 let err_msg = format!("The context retrieval is not supported in the {running_mode} mode.");
167
168 #[cfg(feature = "logging")]
169 error!(target: "stdout", "{}", &err_msg);
170
171 return Err(LlamaCoreError::Operation(err_msg));
172 }
173
174 let mut qdrant_client = qdrant::Qdrant::new_with_url(vdb_server_url.as_ref().to_string());
176
177 if let Some(key) = vdb_api_key.as_deref() {
179 if !key.is_empty() {
180 #[cfg(feature = "logging")]
181 debug!(target: "stdout", "Set the API key for the VectorDB server.");
182
183 qdrant_client.set_api_key(key);
184 }
185 }
186
187 let scored_points = match qdrant_search_similar_points(
189 &qdrant_client,
190 vdb_collection_name.as_ref(),
191 query_embedding,
192 limit,
193 score_threshold,
194 )
195 .await
196 {
197 Ok(points) => points,
198 Err(e) => {
199 #[cfg(feature = "logging")]
200 error!(target: "stdout", "{e}");
201
202 return Err(e);
203 }
204 };
205
206 #[cfg(feature = "logging")]
207 info!(target: "stdout", "remove duplicates from {} scored points", scored_points.len());
208
209 let mut seen = HashSet::new();
211 let unique_scored_points: Vec<ScoredPoint> = scored_points
212 .into_iter()
213 .filter(|point| {
214 seen.insert(
215 point
216 .payload
217 .as_ref()
218 .unwrap()
219 .get("source")
220 .unwrap()
221 .to_string(),
222 )
223 })
224 .collect();
225
226 #[cfg(feature = "logging")]
227 info!(target: "stdout", "number of unique scored points: {}", unique_scored_points.len());
228
229 let ro = match unique_scored_points.is_empty() {
230 true => RetrieveObject {
231 points: None,
232 limit,
233 score_threshold: score_threshold.unwrap_or(0.0),
234 },
235 false => {
236 let mut points: Vec<RagScoredPoint> = vec![];
237 for point in unique_scored_points.iter() {
238 if let Some(payload) = &point.payload {
239 if let Some(source) = payload.get("source").and_then(Value::as_str) {
240 points.push(RagScoredPoint {
241 source: source.to_string(),
242 score: point.score,
243 })
244 }
245
246 #[cfg(feature = "logging")]
248 if let Some(search) = payload.get("search").and_then(Value::as_str) {
249 info!(target: "stdout", "search: {search}");
250 }
251 }
252 }
253
254 RetrieveObject {
255 points: Some(points),
256 limit,
257 score_threshold: score_threshold.unwrap_or(0.0),
258 }
259 }
260 };
261
262 Ok(ro)
263}
264
265async fn qdrant_create_collection(
266 qdrant_client: &qdrant::Qdrant,
267 collection_name: impl AsRef<str>,
268 dim: usize,
269) -> Result<(), LlamaCoreError> {
270 #[cfg(feature = "logging")]
271 info!(target: "stdout", "Create a Qdrant collection named {} of {} dimensions.", collection_name.as_ref(), dim);
272
273 if let Err(e) = qdrant_client
274 .create_collection(collection_name.as_ref(), dim as u32)
275 .await
276 {
277 let err_msg = e.to_string();
278
279 #[cfg(feature = "logging")]
280 error!(target: "stdout", "{}", &err_msg);
281
282 return Err(LlamaCoreError::Qdrant(err_msg));
283 }
284
285 Ok(())
286}
287
288async fn qdrant_persist_embeddings(
289 qdrant_client: &qdrant::Qdrant,
290 collection_name: impl AsRef<str>,
291 embeddings: &[EmbeddingObject],
292 chunks: &[String],
293) -> Result<(), LlamaCoreError> {
294 #[cfg(feature = "logging")]
295 info!(target: "stdout", "Persist embeddings to the Qdrant instance.");
296
297 let mut points = Vec::<Point>::new();
298 for embedding in embeddings {
299 let vector: Vec<_> = embedding.embedding.iter().map(|x| *x as f32).collect();
301
302 let payload = serde_json::json!({"source": chunks[embedding.index as usize]})
304 .as_object()
305 .map(|m| m.to_owned());
306
307 let p = Point {
309 id: PointId::Num(embedding.index),
310 vector,
311 payload,
312 };
313
314 points.push(p);
315 }
316
317 #[cfg(feature = "logging")]
318 info!(target: "stdout", "Number of points to be upserted: {}", points.len());
319
320 if let Err(e) = qdrant_client
321 .upsert_points(collection_name.as_ref(), points)
322 .await
323 {
324 let err_msg = format!("{e}");
325
326 #[cfg(feature = "logging")]
327 error!(target: "stdout", "{}", &err_msg);
328
329 return Err(LlamaCoreError::Qdrant(err_msg));
330 }
331
332 Ok(())
333}
334
335async fn qdrant_search_similar_points(
336 qdrant_client: &qdrant::Qdrant,
337 collection_name: impl AsRef<str>,
338 query_vector: &[f32],
339 limit: usize,
340 score_threshold: Option<f32>,
341) -> Result<Vec<ScoredPoint>, LlamaCoreError> {
342 #[cfg(feature = "logging")]
343 info!(target: "stdout", "Search similar points from the qdrant instance.");
344
345 match qdrant_client
346 .search_points(
347 collection_name.as_ref(),
348 query_vector.to_vec(),
349 limit as u64,
350 score_threshold,
351 )
352 .await
353 {
354 Ok(search_result) => {
355 #[cfg(feature = "logging")]
356 info!(target: "stdout", "Number of similar points found: {}", search_result.len());
357
358 Ok(search_result)
359 }
360 Err(e) => {
361 let err_msg = e.to_string();
362
363 #[cfg(feature = "logging")]
364 error!(target: "stdout", "{}", &err_msg);
365
366 Err(LlamaCoreError::Qdrant(err_msg))
367 }
368 }
369}