llama_core/
rag.rs

1//! Define APIs for RAG operations.
2
3use crate::{embeddings::embeddings, error::LlamaCoreError, running_mode, RunningMode};
4use endpoints::{
5    embeddings::{EmbeddingObject, EmbeddingRequest, EmbeddingsResponse, InputText},
6    rag::{RagScoredPoint, RetrieveObject},
7};
8use qdrant::*;
9use serde_json::Value;
10use std::collections::HashSet;
11
12/// Convert document chunks to embeddings.
13///
14/// # Arguments
15///
16/// * `embedding_request` - A reference to an `EmbeddingRequest` object.
17///
18/// # Returns
19///
20/// Name of the Qdrant collection if successful.
21pub async fn rag_doc_chunks_to_embeddings(
22    embedding_request: &EmbeddingRequest,
23) -> Result<EmbeddingsResponse, LlamaCoreError> {
24    #[cfg(feature = "logging")]
25    info!(target: "stdout", "Convert document chunks to embeddings.");
26
27    let running_mode = running_mode()?;
28    if running_mode != RunningMode::RAG {
29        let err_msg =
30            format!("Creating knowledge base is not supported in the {running_mode} mode.");
31
32        #[cfg(feature = "logging")]
33        error!(target: "stdout", "{}", &err_msg);
34
35        return Err(LlamaCoreError::Operation(err_msg));
36    }
37
38    let qdrant_url = match embedding_request.vdb_server_url.as_deref() {
39        Some(url) => url.to_string(),
40        None => {
41            let err_msg = "The VectorDB server URL is not provided.";
42
43            #[cfg(feature = "logging")]
44            error!(target: "stdout", "{}", &err_msg);
45
46            return Err(LlamaCoreError::Operation(err_msg.into()));
47        }
48    };
49    let qdrant_collection_name = match embedding_request.vdb_collection_name.as_deref() {
50        Some(name) => name.to_string(),
51        None => {
52            let err_msg = "The VectorDB collection name is not provided.";
53
54            #[cfg(feature = "logging")]
55            error!(target: "stdout", "{}", &err_msg);
56
57            return Err(LlamaCoreError::Operation(err_msg.into()));
58        }
59    };
60
61    #[cfg(feature = "logging")]
62    info!(target: "stdout", "Compute embeddings for document chunks.");
63
64    #[cfg(feature = "logging")]
65    if let Ok(request_str) = serde_json::to_string(&embedding_request) {
66        debug!(target: "stdout", "Embedding request: {request_str}");
67    }
68
69    // compute embeddings for the document
70    let embeddings_response = embeddings(embedding_request).await?;
71    let embeddings = embeddings_response.data.as_slice();
72    let dim = embeddings[0].embedding.len();
73
74    // create a Qdrant client
75    let mut qdrant_client = qdrant::Qdrant::new_with_url(qdrant_url);
76
77    // set the API key if provided
78    if let Some(key) = embedding_request.vdb_api_key.as_deref() {
79        if !key.is_empty() {
80            #[cfg(feature = "logging")]
81            debug!(target: "stdout", "Set the API key for the VectorDB server.");
82
83            qdrant_client.set_api_key(key);
84        }
85    }
86
87    // create a collection
88    qdrant_create_collection(&qdrant_client, &qdrant_collection_name, dim).await?;
89
90    let chunks = match &embedding_request.input {
91        InputText::String(text) => vec![text.clone()],
92        InputText::ArrayOfStrings(texts) => texts.clone(),
93        InputText::ArrayOfTokens(tokens) => tokens.iter().map(|t| t.to_string()).collect(),
94        InputText::ArrayOfTokenArrays(token_arrays) => token_arrays
95            .iter()
96            .map(|tokens| tokens.iter().map(|t| t.to_string()).collect())
97            .collect(),
98    };
99
100    // create and upsert points
101    qdrant_persist_embeddings(
102        &qdrant_client,
103        &qdrant_collection_name,
104        embeddings,
105        chunks.as_slice(),
106    )
107    .await?;
108
109    Ok(embeddings_response)
110}
111
112/// Convert a query to embeddings.
113///
114/// # Arguments
115///
116/// * `embedding_request` - A reference to an `EmbeddingRequest` object.
117pub async fn rag_query_to_embeddings(
118    embedding_request: &EmbeddingRequest,
119) -> Result<EmbeddingsResponse, LlamaCoreError> {
120    #[cfg(feature = "logging")]
121    info!(target: "stdout", "Compute embeddings for the user query.");
122
123    let running_mode = running_mode()?;
124    if running_mode != RunningMode::RAG {
125        let err_msg = format!("The RAG query is not supported in the {running_mode} mode.",);
126
127        #[cfg(feature = "logging")]
128        error!(target: "stdout", "{}", &err_msg);
129
130        return Err(LlamaCoreError::Operation(err_msg));
131    }
132
133    embeddings(embedding_request).await
134}
135
136/// Retrieve similar points from the Qdrant server using the query embedding
137///
138/// # Arguments
139///
140/// * `query_embedding` - A reference to a query embedding.
141///
142/// * `qdrant_url` - URL of the Qdrant server.
143///
144/// * `qdrant_collection_name` - Name of the Qdrant collection to be created.
145///
146/// * `limit` - Number of retrieved results.
147///
148/// * `score_threshold` - The minimum score of the retrieved results.
149pub async fn rag_retrieve_context(
150    query_embedding: &[f32],
151    vdb_server_url: impl AsRef<str>,
152    vdb_collection_name: impl AsRef<str>,
153    limit: usize,
154    score_threshold: Option<f32>,
155    vdb_api_key: Option<String>,
156) -> Result<RetrieveObject, LlamaCoreError> {
157    #[cfg(feature = "logging")]
158    {
159        info!(target: "stdout", "Retrieve context.");
160
161        info!(target: "stdout", "qdrant_url: {}, qdrant_collection_name: {}, limit: {}, score_threshold: {}", vdb_server_url.as_ref(), vdb_collection_name.as_ref(), limit, score_threshold.unwrap_or_default());
162    }
163
164    let running_mode = running_mode()?;
165    if running_mode != RunningMode::RAG {
166        let err_msg = format!("The context retrieval is not supported in the {running_mode} mode.");
167
168        #[cfg(feature = "logging")]
169        error!(target: "stdout", "{}", &err_msg);
170
171        return Err(LlamaCoreError::Operation(err_msg));
172    }
173
174    // create a Qdrant client
175    let mut qdrant_client = qdrant::Qdrant::new_with_url(vdb_server_url.as_ref().to_string());
176
177    // set the API key if provided
178    if let Some(key) = vdb_api_key.as_deref() {
179        if !key.is_empty() {
180            #[cfg(feature = "logging")]
181            debug!(target: "stdout", "Set the API key for the VectorDB server.");
182
183            qdrant_client.set_api_key(key);
184        }
185    }
186
187    // search for similar points
188    let scored_points = match qdrant_search_similar_points(
189        &qdrant_client,
190        vdb_collection_name.as_ref(),
191        query_embedding,
192        limit,
193        score_threshold,
194    )
195    .await
196    {
197        Ok(points) => points,
198        Err(e) => {
199            #[cfg(feature = "logging")]
200            error!(target: "stdout", "{e}");
201
202            return Err(e);
203        }
204    };
205
206    #[cfg(feature = "logging")]
207    info!(target: "stdout", "remove duplicates from {} scored points", scored_points.len());
208
209    // remove duplicates, which have the same source
210    let mut seen = HashSet::new();
211    let unique_scored_points: Vec<ScoredPoint> = scored_points
212        .into_iter()
213        .filter(|point| {
214            seen.insert(
215                point
216                    .payload
217                    .as_ref()
218                    .unwrap()
219                    .get("source")
220                    .unwrap()
221                    .to_string(),
222            )
223        })
224        .collect();
225
226    #[cfg(feature = "logging")]
227    info!(target: "stdout", "number of unique scored points: {}", unique_scored_points.len());
228
229    let ro = match unique_scored_points.is_empty() {
230        true => RetrieveObject {
231            points: None,
232            limit,
233            score_threshold: score_threshold.unwrap_or(0.0),
234        },
235        false => {
236            let mut points: Vec<RagScoredPoint> = vec![];
237            for point in unique_scored_points.iter() {
238                if let Some(payload) = &point.payload {
239                    if let Some(source) = payload.get("source").and_then(Value::as_str) {
240                        points.push(RagScoredPoint {
241                            source: source.to_string(),
242                            score: point.score,
243                        })
244                    }
245
246                    // For debugging purpose, log the optional search field if it exists
247                    #[cfg(feature = "logging")]
248                    if let Some(search) = payload.get("search").and_then(Value::as_str) {
249                        info!(target: "stdout", "search: {search}");
250                    }
251                }
252            }
253
254            RetrieveObject {
255                points: Some(points),
256                limit,
257                score_threshold: score_threshold.unwrap_or(0.0),
258            }
259        }
260    };
261
262    Ok(ro)
263}
264
265async fn qdrant_create_collection(
266    qdrant_client: &qdrant::Qdrant,
267    collection_name: impl AsRef<str>,
268    dim: usize,
269) -> Result<(), LlamaCoreError> {
270    #[cfg(feature = "logging")]
271    info!(target: "stdout", "Create a Qdrant collection named {} of {} dimensions.", collection_name.as_ref(), dim);
272
273    if let Err(e) = qdrant_client
274        .create_collection(collection_name.as_ref(), dim as u32)
275        .await
276    {
277        let err_msg = e.to_string();
278
279        #[cfg(feature = "logging")]
280        error!(target: "stdout", "{}", &err_msg);
281
282        return Err(LlamaCoreError::Qdrant(err_msg));
283    }
284
285    Ok(())
286}
287
288async fn qdrant_persist_embeddings(
289    qdrant_client: &qdrant::Qdrant,
290    collection_name: impl AsRef<str>,
291    embeddings: &[EmbeddingObject],
292    chunks: &[String],
293) -> Result<(), LlamaCoreError> {
294    #[cfg(feature = "logging")]
295    info!(target: "stdout", "Persist embeddings to the Qdrant instance.");
296
297    let mut points = Vec::<Point>::new();
298    for embedding in embeddings {
299        // convert the embedding to a vector
300        let vector: Vec<_> = embedding.embedding.iter().map(|x| *x as f32).collect();
301
302        // create a payload
303        let payload = serde_json::json!({"source": chunks[embedding.index as usize]})
304            .as_object()
305            .map(|m| m.to_owned());
306
307        // create a point
308        let p = Point {
309            id: PointId::Num(embedding.index),
310            vector,
311            payload,
312        };
313
314        points.push(p);
315    }
316
317    #[cfg(feature = "logging")]
318    info!(target: "stdout", "Number of points to be upserted: {}", points.len());
319
320    if let Err(e) = qdrant_client
321        .upsert_points(collection_name.as_ref(), points)
322        .await
323    {
324        let err_msg = format!("{e}");
325
326        #[cfg(feature = "logging")]
327        error!(target: "stdout", "{}", &err_msg);
328
329        return Err(LlamaCoreError::Qdrant(err_msg));
330    }
331
332    Ok(())
333}
334
335async fn qdrant_search_similar_points(
336    qdrant_client: &qdrant::Qdrant,
337    collection_name: impl AsRef<str>,
338    query_vector: &[f32],
339    limit: usize,
340    score_threshold: Option<f32>,
341) -> Result<Vec<ScoredPoint>, LlamaCoreError> {
342    #[cfg(feature = "logging")]
343    info!(target: "stdout", "Search similar points from the qdrant instance.");
344
345    match qdrant_client
346        .search_points(
347            collection_name.as_ref(),
348            query_vector.to_vec(),
349            limit as u64,
350            score_threshold,
351        )
352        .await
353    {
354        Ok(search_result) => {
355            #[cfg(feature = "logging")]
356            info!(target: "stdout", "Number of similar points found: {}", search_result.len());
357
358            Ok(search_result)
359        }
360        Err(e) => {
361            let err_msg = e.to_string();
362
363            #[cfg(feature = "logging")]
364            error!(target: "stdout", "{}", &err_msg);
365
366            Err(LlamaCoreError::Qdrant(err_msg))
367        }
368    }
369}