1use crate::{embeddings::embeddings, error::LlamaCoreError, running_mode, RunningMode};
4use endpoints::{
5 embeddings::{EmbeddingObject, EmbeddingRequest, EmbeddingsResponse, InputText},
6 rag::vector_search::{DataFrom, RagScoredPoint, RetrieveObject},
7};
8use qdrant::*;
9use serde_json::Value;
10use std::collections::HashSet;
11
12pub async fn rag_doc_chunks_to_embeddings(
22 embedding_request: &EmbeddingRequest,
23) -> Result<EmbeddingsResponse, LlamaCoreError> {
24 #[cfg(feature = "logging")]
25 info!(target: "stdout", "Convert document chunks to embeddings.");
26
27 let running_mode = running_mode()?;
28 if running_mode != RunningMode::RAG {
29 let err_msg =
30 format!("Creating knowledge base is not supported in the {running_mode} mode.");
31
32 #[cfg(feature = "logging")]
33 error!(target: "stdout", "{}", &err_msg);
34
35 return Err(LlamaCoreError::Operation(err_msg));
36 }
37
38 let qdrant_url = match embedding_request.vdb_server_url.as_deref() {
39 Some(url) => url.to_string(),
40 None => {
41 let err_msg = "The VectorDB server URL is not provided.";
42
43 #[cfg(feature = "logging")]
44 error!(target: "stdout", "{}", &err_msg);
45
46 return Err(LlamaCoreError::Operation(err_msg.into()));
47 }
48 };
49 let qdrant_collection_name = match embedding_request.vdb_collection_name.as_deref() {
50 Some(name) => name.to_string(),
51 None => {
52 let err_msg = "The VectorDB collection name is not provided.";
53
54 #[cfg(feature = "logging")]
55 error!(target: "stdout", "{}", &err_msg);
56
57 return Err(LlamaCoreError::Operation(err_msg.into()));
58 }
59 };
60
61 #[cfg(feature = "logging")]
62 info!(target: "stdout", "Compute embeddings for document chunks.");
63
64 #[cfg(feature = "logging")]
65 if let Ok(request_str) = serde_json::to_string(&embedding_request) {
66 debug!(target: "stdout", "Embedding request: {request_str}");
67 }
68
69 let embeddings_response = embeddings(embedding_request).await?;
71 let embeddings = embeddings_response.data.as_slice();
72 let dim = embeddings[0].embedding.len();
73
74 let mut qdrant_client = qdrant::Qdrant::new_with_url(qdrant_url);
76
77 if let Some(key) = embedding_request.vdb_api_key.as_deref() {
79 if !key.is_empty() {
80 #[cfg(feature = "logging")]
81 debug!(target: "stdout", "Set the API key for the VectorDB server.");
82
83 qdrant_client.set_api_key(key);
84 }
85 }
86
87 qdrant_create_collection(&qdrant_client, &qdrant_collection_name, dim).await?;
89
90 let chunks = match &embedding_request.input {
91 InputText::String(text) => vec![text.clone()],
92 InputText::ArrayOfStrings(texts) => texts.clone(),
93 InputText::ArrayOfTokens(tokens) => tokens.iter().map(|t| t.to_string()).collect(),
94 InputText::ArrayOfTokenArrays(token_arrays) => token_arrays
95 .iter()
96 .map(|tokens| tokens.iter().map(|t| t.to_string()).collect())
97 .collect(),
98 };
99
100 qdrant_persist_embeddings(
102 &qdrant_client,
103 &qdrant_collection_name,
104 embeddings,
105 chunks.as_slice(),
106 )
107 .await?;
108
109 Ok(embeddings_response)
110}
111
112pub async fn rag_query_to_embeddings(
118 embedding_request: &EmbeddingRequest,
119) -> Result<EmbeddingsResponse, LlamaCoreError> {
120 #[cfg(feature = "logging")]
121 info!(target: "stdout", "Compute embeddings for the user query.");
122
123 let running_mode = running_mode()?;
124 if running_mode != RunningMode::RAG {
125 let err_msg = format!("The RAG query is not supported in the {running_mode} mode.",);
126
127 #[cfg(feature = "logging")]
128 error!(target: "stdout", "{}", &err_msg);
129
130 return Err(LlamaCoreError::Operation(err_msg));
131 }
132
133 embeddings(embedding_request).await
134}
135
136pub async fn rag_retrieve_context(
150 query_embedding: &[f32],
151 vdb_server_url: impl AsRef<str>,
152 vdb_collection_name: impl AsRef<str>,
153 limit: usize,
154 score_threshold: Option<f32>,
155 vdb_api_key: Option<String>,
156) -> Result<RetrieveObject, LlamaCoreError> {
157 #[cfg(feature = "logging")]
158 {
159 info!(target: "stdout", "Retrieve context.");
160
161 info!(target: "stdout", "qdrant_url: {}, qdrant_collection_name: {}, limit: {}, score_threshold: {}", vdb_server_url.as_ref(), vdb_collection_name.as_ref(), limit, score_threshold.unwrap_or_default());
162 }
163
164 let running_mode = running_mode()?;
165 if running_mode != RunningMode::RAG {
166 let err_msg = format!("The context retrieval is not supported in the {running_mode} mode.");
167
168 #[cfg(feature = "logging")]
169 error!(target: "stdout", "{}", &err_msg);
170
171 return Err(LlamaCoreError::Operation(err_msg));
172 }
173
174 let mut qdrant_client = qdrant::Qdrant::new_with_url(vdb_server_url.as_ref().to_string());
176
177 if let Some(key) = vdb_api_key.as_deref() {
179 if !key.is_empty() {
180 #[cfg(feature = "logging")]
181 debug!(target: "stdout", "Set the API key for the VectorDB server.");
182
183 qdrant_client.set_api_key(key);
184 }
185 }
186
187 let scored_points = match qdrant_search_similar_points(
189 &qdrant_client,
190 vdb_collection_name.as_ref(),
191 query_embedding,
192 limit,
193 score_threshold,
194 )
195 .await
196 {
197 Ok(points) => points,
198 Err(e) => {
199 #[cfg(feature = "logging")]
200 error!(target: "stdout", "{e}");
201
202 return Err(e);
203 }
204 };
205
206 #[cfg(feature = "logging")]
207 info!(target: "stdout", "remove duplicates from {} scored points", scored_points.len());
208
209 let mut seen = HashSet::new();
211 let unique_scored_points: Vec<ScoredPoint> = scored_points
212 .into_iter()
213 .filter(|point| {
214 seen.insert(
215 point
216 .payload
217 .as_ref()
218 .unwrap()
219 .get("source")
220 .unwrap()
221 .to_string(),
222 )
223 })
224 .collect();
225
226 #[cfg(feature = "logging")]
227 info!(target: "stdout", "number of unique scored points: {}", unique_scored_points.len());
228
229 let ro = match unique_scored_points.is_empty() {
230 true => RetrieveObject {
231 points: None,
232 limit,
233 score_threshold: score_threshold.unwrap_or(0.0),
234 },
235 false => {
236 let mut points: Vec<RagScoredPoint> = vec![];
237 for point in unique_scored_points.iter() {
238 if let Some(payload) = &point.payload {
239 if let Some(source) = payload.get("source").and_then(Value::as_str) {
240 points.push(RagScoredPoint {
241 source: source.to_string(),
242 score: point.score as f64,
243 from: DataFrom::VectorSearch,
244 })
245 }
246
247 #[cfg(feature = "logging")]
249 if let Some(search) = payload.get("search").and_then(Value::as_str) {
250 info!(target: "stdout", "search: {search}");
251 }
252 }
253 }
254
255 RetrieveObject {
256 points: Some(points),
257 limit,
258 score_threshold: score_threshold.unwrap_or(0.0),
259 }
260 }
261 };
262
263 Ok(ro)
264}
265
266async fn qdrant_create_collection(
267 qdrant_client: &qdrant::Qdrant,
268 collection_name: impl AsRef<str>,
269 dim: usize,
270) -> Result<(), LlamaCoreError> {
271 #[cfg(feature = "logging")]
272 info!(target: "stdout", "Create a Qdrant collection named {} of {} dimensions.", collection_name.as_ref(), dim);
273
274 if let Err(e) = qdrant_client
275 .create_collection(collection_name.as_ref(), dim as u32)
276 .await
277 {
278 let err_msg = e.to_string();
279
280 #[cfg(feature = "logging")]
281 error!(target: "stdout", "{}", &err_msg);
282
283 return Err(LlamaCoreError::Qdrant(err_msg));
284 }
285
286 Ok(())
287}
288
289async fn qdrant_persist_embeddings(
290 qdrant_client: &qdrant::Qdrant,
291 collection_name: impl AsRef<str>,
292 embeddings: &[EmbeddingObject],
293 chunks: &[String],
294) -> Result<(), LlamaCoreError> {
295 #[cfg(feature = "logging")]
296 info!(target: "stdout", "Persist embeddings to the Qdrant instance.");
297
298 let mut points = Vec::<Point>::new();
299 for embedding in embeddings {
300 let vector: Vec<_> = embedding.embedding.iter().map(|x| *x as f32).collect();
302
303 let payload = serde_json::json!({"source": chunks[embedding.index as usize]})
305 .as_object()
306 .map(|m| m.to_owned());
307
308 let p = Point {
310 id: PointId::Num(embedding.index),
311 vector,
312 payload,
313 };
314
315 points.push(p);
316 }
317
318 #[cfg(feature = "logging")]
319 info!(target: "stdout", "Number of points to be upserted: {}", points.len());
320
321 if let Err(e) = qdrant_client
322 .upsert_points(collection_name.as_ref(), points)
323 .await
324 {
325 let err_msg = format!("{e}");
326
327 #[cfg(feature = "logging")]
328 error!(target: "stdout", "{}", &err_msg);
329
330 return Err(LlamaCoreError::Qdrant(err_msg));
331 }
332
333 Ok(())
334}
335
336async fn qdrant_search_similar_points(
337 qdrant_client: &qdrant::Qdrant,
338 collection_name: impl AsRef<str>,
339 query_vector: &[f32],
340 limit: usize,
341 score_threshold: Option<f32>,
342) -> Result<Vec<ScoredPoint>, LlamaCoreError> {
343 #[cfg(feature = "logging")]
344 info!(target: "stdout", "Search similar points from the qdrant instance.");
345
346 match qdrant_client
347 .search_points(
348 collection_name.as_ref(),
349 query_vector.to_vec(),
350 limit as u64,
351 score_threshold,
352 )
353 .await
354 {
355 Ok(search_result) => {
356 #[cfg(feature = "logging")]
357 info!(target: "stdout", "Number of similar points found: {}", search_result.len());
358
359 Ok(search_result)
360 }
361 Err(e) => {
362 let err_msg = e.to_string();
363
364 #[cfg(feature = "logging")]
365 error!(target: "stdout", "{}", &err_msg);
366
367 Err(LlamaCoreError::Qdrant(err_msg))
368 }
369 }
370}