llama_core/
rag.rs

1//! Define APIs for RAG operations.
2
3use crate::{embeddings::embeddings, error::LlamaCoreError, running_mode, RunningMode};
4use endpoints::{
5    embeddings::{EmbeddingObject, EmbeddingRequest, EmbeddingsResponse, InputText},
6    rag::{RagScoredPoint, RetrieveObject},
7};
8use qdrant::*;
9use serde_json::Value;
10use std::collections::HashSet;
11
12/// Convert document chunks to embeddings.
13///
14/// # Arguments
15///
16/// * `embedding_request` - A reference to an `EmbeddingRequest` object.
17///
18/// # Returns
19///
20/// Name of the Qdrant collection if successful.
21pub async fn rag_doc_chunks_to_embeddings(
22    embedding_request: &EmbeddingRequest,
23) -> Result<EmbeddingsResponse, LlamaCoreError> {
24    #[cfg(feature = "logging")]
25    info!(target: "stdout", "Convert document chunks to embeddings.");
26
27    let running_mode = running_mode()?;
28    if running_mode != RunningMode::RAG {
29        let err_msg = format!(
30            "Creating knowledge base is not supported in the {} mode.",
31            running_mode
32        );
33
34        #[cfg(feature = "logging")]
35        error!(target: "stdout", "{}", &err_msg);
36
37        return Err(LlamaCoreError::Operation(err_msg));
38    }
39
40    let qdrant_url = match embedding_request.vdb_server_url.as_deref() {
41        Some(url) => url.to_string(),
42        None => {
43            let err_msg = "The VectorDB server URL is not provided.";
44
45            #[cfg(feature = "logging")]
46            error!(target: "stdout", "{}", &err_msg);
47
48            return Err(LlamaCoreError::Operation(err_msg.into()));
49        }
50    };
51    let qdrant_collection_name = match embedding_request.vdb_collection_name.as_deref() {
52        Some(name) => name.to_string(),
53        None => {
54            let err_msg = "The VectorDB collection name is not provided.";
55
56            #[cfg(feature = "logging")]
57            error!(target: "stdout", "{}", &err_msg);
58
59            return Err(LlamaCoreError::Operation(err_msg.into()));
60        }
61    };
62
63    #[cfg(feature = "logging")]
64    info!(target: "stdout", "Compute embeddings for document chunks.");
65
66    #[cfg(feature = "logging")]
67    if let Ok(request_str) = serde_json::to_string(&embedding_request) {
68        debug!(target: "stdout", "Embedding request: {}", request_str);
69    }
70
71    // compute embeddings for the document
72    let embeddings_response = embeddings(embedding_request).await?;
73    let embeddings = embeddings_response.data.as_slice();
74    let dim = embeddings[0].embedding.len();
75
76    // create a Qdrant client
77    let mut qdrant_client = qdrant::Qdrant::new_with_url(qdrant_url);
78
79    // set the API key if provided
80    if let Some(key) = embedding_request.vdb_api_key.as_deref() {
81        if !key.is_empty() {
82            #[cfg(feature = "logging")]
83            debug!(target: "stdout", "Set the API key for the VectorDB server.");
84
85            qdrant_client.set_api_key(key);
86        }
87    }
88
89    // create a collection
90    qdrant_create_collection(&qdrant_client, &qdrant_collection_name, dim).await?;
91
92    let chunks = match &embedding_request.input {
93        InputText::String(text) => vec![text.clone()],
94        InputText::ArrayOfStrings(texts) => texts.clone(),
95        InputText::ArrayOfTokens(tokens) => tokens.iter().map(|t| t.to_string()).collect(),
96        InputText::ArrayOfTokenArrays(token_arrays) => token_arrays
97            .iter()
98            .map(|tokens| tokens.iter().map(|t| t.to_string()).collect())
99            .collect(),
100    };
101
102    // create and upsert points
103    qdrant_persist_embeddings(
104        &qdrant_client,
105        &qdrant_collection_name,
106        embeddings,
107        chunks.as_slice(),
108    )
109    .await?;
110
111    Ok(embeddings_response)
112}
113
114/// Convert a query to embeddings.
115///
116/// # Arguments
117///
118/// * `embedding_request` - A reference to an `EmbeddingRequest` object.
119pub async fn rag_query_to_embeddings(
120    embedding_request: &EmbeddingRequest,
121) -> Result<EmbeddingsResponse, LlamaCoreError> {
122    #[cfg(feature = "logging")]
123    info!(target: "stdout", "Compute embeddings for the user query.");
124
125    let running_mode = running_mode()?;
126    if running_mode != RunningMode::RAG {
127        let err_msg = format!("The RAG query is not supported in the {running_mode} mode.",);
128
129        #[cfg(feature = "logging")]
130        error!(target: "stdout", "{}", &err_msg);
131
132        return Err(LlamaCoreError::Operation(err_msg));
133    }
134
135    embeddings(embedding_request).await
136}
137
138/// Retrieve similar points from the Qdrant server using the query embedding
139///
140/// # Arguments
141///
142/// * `query_embedding` - A reference to a query embedding.
143///
144/// * `qdrant_url` - URL of the Qdrant server.
145///
146/// * `qdrant_collection_name` - Name of the Qdrant collection to be created.
147///
148/// * `limit` - Number of retrieved results.
149///
150/// * `score_threshold` - The minimum score of the retrieved results.
151pub async fn rag_retrieve_context(
152    query_embedding: &[f32],
153    vdb_server_url: impl AsRef<str>,
154    vdb_collection_name: impl AsRef<str>,
155    limit: usize,
156    score_threshold: Option<f32>,
157    vdb_api_key: Option<String>,
158) -> Result<RetrieveObject, LlamaCoreError> {
159    #[cfg(feature = "logging")]
160    {
161        info!(target: "stdout", "Retrieve context.");
162
163        info!(target: "stdout", "qdrant_url: {}, qdrant_collection_name: {}, limit: {}, score_threshold: {}", vdb_server_url.as_ref(), vdb_collection_name.as_ref(), limit, score_threshold.unwrap_or_default());
164    }
165
166    let running_mode = running_mode()?;
167    if running_mode != RunningMode::RAG {
168        let err_msg = format!(
169            "The context retrieval is not supported in the {} mode.",
170            running_mode
171        );
172
173        #[cfg(feature = "logging")]
174        error!(target: "stdout", "{}", &err_msg);
175
176        return Err(LlamaCoreError::Operation(err_msg));
177    }
178
179    // create a Qdrant client
180    let mut qdrant_client = qdrant::Qdrant::new_with_url(vdb_server_url.as_ref().to_string());
181
182    // set the API key if provided
183    if let Some(key) = vdb_api_key.as_deref() {
184        if !key.is_empty() {
185            #[cfg(feature = "logging")]
186            debug!(target: "stdout", "Set the API key for the VectorDB server.");
187
188            qdrant_client.set_api_key(key);
189        }
190    }
191
192    // search for similar points
193    let scored_points = match qdrant_search_similar_points(
194        &qdrant_client,
195        vdb_collection_name.as_ref(),
196        query_embedding,
197        limit,
198        score_threshold,
199    )
200    .await
201    {
202        Ok(points) => points,
203        Err(e) => {
204            #[cfg(feature = "logging")]
205            error!(target: "stdout", "{}", e.to_string());
206
207            return Err(e);
208        }
209    };
210
211    #[cfg(feature = "logging")]
212    info!(target: "stdout", "remove duplicates from {} scored points", scored_points.len());
213
214    // remove duplicates, which have the same source
215    let mut seen = HashSet::new();
216    let unique_scored_points: Vec<ScoredPoint> = scored_points
217        .into_iter()
218        .filter(|point| {
219            seen.insert(
220                point
221                    .payload
222                    .as_ref()
223                    .unwrap()
224                    .get("source")
225                    .unwrap()
226                    .to_string(),
227            )
228        })
229        .collect();
230
231    #[cfg(feature = "logging")]
232    info!(target: "stdout", "number of unique scored points: {}", unique_scored_points.len());
233
234    let ro = match unique_scored_points.is_empty() {
235        true => RetrieveObject {
236            points: None,
237            limit,
238            score_threshold: score_threshold.unwrap_or(0.0),
239        },
240        false => {
241            let mut points: Vec<RagScoredPoint> = vec![];
242            for point in unique_scored_points.iter() {
243                if let Some(payload) = &point.payload {
244                    if let Some(source) = payload.get("source").and_then(Value::as_str) {
245                        points.push(RagScoredPoint {
246                            source: source.to_string(),
247                            score: point.score,
248                        })
249                    }
250
251                    // For debugging purpose, log the optional search field if it exists
252                    #[cfg(feature = "logging")]
253                    if let Some(search) = payload.get("search").and_then(Value::as_str) {
254                        info!(target: "stdout", "search: {}", search);
255                    }
256                }
257            }
258
259            RetrieveObject {
260                points: Some(points),
261                limit,
262                score_threshold: score_threshold.unwrap_or(0.0),
263            }
264        }
265    };
266
267    Ok(ro)
268}
269
270async fn qdrant_create_collection(
271    qdrant_client: &qdrant::Qdrant,
272    collection_name: impl AsRef<str>,
273    dim: usize,
274) -> Result<(), LlamaCoreError> {
275    #[cfg(feature = "logging")]
276    info!(target: "stdout", "Create a Qdrant collection named {} of {} dimensions.", collection_name.as_ref(), dim);
277
278    if let Err(e) = qdrant_client
279        .create_collection(collection_name.as_ref(), dim as u32)
280        .await
281    {
282        let err_msg = e.to_string();
283
284        #[cfg(feature = "logging")]
285        error!(target: "stdout", "{}", &err_msg);
286
287        return Err(LlamaCoreError::Qdrant(err_msg));
288    }
289
290    Ok(())
291}
292
293async fn qdrant_persist_embeddings(
294    qdrant_client: &qdrant::Qdrant,
295    collection_name: impl AsRef<str>,
296    embeddings: &[EmbeddingObject],
297    chunks: &[String],
298) -> Result<(), LlamaCoreError> {
299    #[cfg(feature = "logging")]
300    info!(target: "stdout", "Persist embeddings to the Qdrant instance.");
301
302    let mut points = Vec::<Point>::new();
303    for embedding in embeddings {
304        // convert the embedding to a vector
305        let vector: Vec<_> = embedding.embedding.iter().map(|x| *x as f32).collect();
306
307        // create a payload
308        let payload = serde_json::json!({"source": chunks[embedding.index as usize]})
309            .as_object()
310            .map(|m| m.to_owned());
311
312        // create a point
313        let p = Point {
314            id: PointId::Num(embedding.index),
315            vector,
316            payload,
317        };
318
319        points.push(p);
320    }
321
322    #[cfg(feature = "logging")]
323    info!(target: "stdout", "Number of points to be upserted: {}", points.len());
324
325    if let Err(e) = qdrant_client
326        .upsert_points(collection_name.as_ref(), points)
327        .await
328    {
329        let err_msg = format!("{}", e);
330
331        #[cfg(feature = "logging")]
332        error!(target: "stdout", "{}", &err_msg);
333
334        return Err(LlamaCoreError::Qdrant(err_msg));
335    }
336
337    Ok(())
338}
339
340async fn qdrant_search_similar_points(
341    qdrant_client: &qdrant::Qdrant,
342    collection_name: impl AsRef<str>,
343    query_vector: &[f32],
344    limit: usize,
345    score_threshold: Option<f32>,
346) -> Result<Vec<ScoredPoint>, LlamaCoreError> {
347    #[cfg(feature = "logging")]
348    info!(target: "stdout", "Search similar points from the qdrant instance.");
349
350    match qdrant_client
351        .search_points(
352            collection_name.as_ref(),
353            query_vector.to_vec(),
354            limit as u64,
355            score_threshold,
356        )
357        .await
358    {
359        Ok(search_result) => {
360            #[cfg(feature = "logging")]
361            info!(target: "stdout", "Number of similar points found: {}", search_result.len());
362
363            Ok(search_result)
364        }
365        Err(e) => {
366            let err_msg = e.to_string();
367
368            #[cfg(feature = "logging")]
369            error!(target: "stdout", "{}", &err_msg);
370
371            Err(LlamaCoreError::Qdrant(err_msg))
372        }
373    }
374}