llama_core/
rag.rs

1//! Define APIs for RAG operations.
2
3use crate::{embeddings::embeddings, error::LlamaCoreError, running_mode, RunningMode};
4use endpoints::{
5    embeddings::{EmbeddingObject, EmbeddingRequest, EmbeddingsResponse, InputText},
6    rag::vector_search::{DataFrom, RagScoredPoint, RetrieveObject},
7};
8use qdrant::*;
9use serde_json::Value;
10use std::collections::HashSet;
11
12/// Convert document chunks to embeddings.
13///
14/// # Arguments
15///
16/// * `embedding_request` - A reference to an `EmbeddingRequest` object.
17///
18/// # Returns
19///
20/// Name of the Qdrant collection if successful.
21pub async fn rag_doc_chunks_to_embeddings(
22    embedding_request: &EmbeddingRequest,
23) -> Result<EmbeddingsResponse, LlamaCoreError> {
24    #[cfg(feature = "logging")]
25    info!(target: "stdout", "Convert document chunks to embeddings.");
26
27    let running_mode = running_mode()?;
28    if running_mode != RunningMode::RAG {
29        let err_msg =
30            format!("Creating knowledge base is not supported in the {running_mode} mode.");
31
32        #[cfg(feature = "logging")]
33        error!(target: "stdout", "{}", &err_msg);
34
35        return Err(LlamaCoreError::Operation(err_msg));
36    }
37
38    let qdrant_url = match embedding_request.vdb_server_url.as_deref() {
39        Some(url) => url.to_string(),
40        None => {
41            let err_msg = "The VectorDB server URL is not provided.";
42
43            #[cfg(feature = "logging")]
44            error!(target: "stdout", "{}", &err_msg);
45
46            return Err(LlamaCoreError::Operation(err_msg.into()));
47        }
48    };
49    let qdrant_collection_name = match embedding_request.vdb_collection_name.as_deref() {
50        Some(name) => name.to_string(),
51        None => {
52            let err_msg = "The VectorDB collection name is not provided.";
53
54            #[cfg(feature = "logging")]
55            error!(target: "stdout", "{}", &err_msg);
56
57            return Err(LlamaCoreError::Operation(err_msg.into()));
58        }
59    };
60
61    #[cfg(feature = "logging")]
62    info!(target: "stdout", "Compute embeddings for document chunks.");
63
64    #[cfg(feature = "logging")]
65    if let Ok(request_str) = serde_json::to_string(&embedding_request) {
66        debug!(target: "stdout", "Embedding request: {request_str}");
67    }
68
69    // compute embeddings for the document
70    let embeddings_response = embeddings(embedding_request).await?;
71    let embeddings = embeddings_response.data.as_slice();
72    let dim = embeddings[0].embedding.len();
73
74    // create a Qdrant client
75    let mut qdrant_client = qdrant::Qdrant::new_with_url(qdrant_url);
76
77    // set the API key if provided
78    if let Some(key) = embedding_request.vdb_api_key.as_deref() {
79        if !key.is_empty() {
80            #[cfg(feature = "logging")]
81            debug!(target: "stdout", "Set the API key for the VectorDB server.");
82
83            qdrant_client.set_api_key(key);
84        }
85    }
86
87    // create a collection
88    qdrant_create_collection(&qdrant_client, &qdrant_collection_name, dim).await?;
89
90    let chunks = match &embedding_request.input {
91        InputText::String(text) => vec![text.clone()],
92        InputText::ArrayOfStrings(texts) => texts.clone(),
93        InputText::ArrayOfTokens(tokens) => tokens.iter().map(|t| t.to_string()).collect(),
94        InputText::ArrayOfTokenArrays(token_arrays) => token_arrays
95            .iter()
96            .map(|tokens| tokens.iter().map(|t| t.to_string()).collect())
97            .collect(),
98    };
99
100    // create and upsert points
101    qdrant_persist_embeddings(
102        &qdrant_client,
103        &qdrant_collection_name,
104        embeddings,
105        chunks.as_slice(),
106    )
107    .await?;
108
109    Ok(embeddings_response)
110}
111
112/// Convert a query to embeddings.
113///
114/// # Arguments
115///
116/// * `embedding_request` - A reference to an `EmbeddingRequest` object.
117pub async fn rag_query_to_embeddings(
118    embedding_request: &EmbeddingRequest,
119) -> Result<EmbeddingsResponse, LlamaCoreError> {
120    #[cfg(feature = "logging")]
121    info!(target: "stdout", "Compute embeddings for the user query.");
122
123    let running_mode = running_mode()?;
124    if running_mode != RunningMode::RAG {
125        let err_msg = format!("The RAG query is not supported in the {running_mode} mode.",);
126
127        #[cfg(feature = "logging")]
128        error!(target: "stdout", "{}", &err_msg);
129
130        return Err(LlamaCoreError::Operation(err_msg));
131    }
132
133    embeddings(embedding_request).await
134}
135
136/// Retrieve similar points from the Qdrant server using the query embedding
137///
138/// # Arguments
139///
140/// * `query_embedding` - A reference to a query embedding.
141///
142/// * `qdrant_url` - URL of the Qdrant server.
143///
144/// * `qdrant_collection_name` - Name of the Qdrant collection to be created.
145///
146/// * `limit` - Number of retrieved results.
147///
148/// * `score_threshold` - The minimum score of the retrieved results.
149pub async fn rag_retrieve_context(
150    query_embedding: &[f32],
151    vdb_server_url: impl AsRef<str>,
152    vdb_collection_name: impl AsRef<str>,
153    limit: usize,
154    score_threshold: Option<f32>,
155    vdb_api_key: Option<String>,
156) -> Result<RetrieveObject, LlamaCoreError> {
157    #[cfg(feature = "logging")]
158    {
159        info!(target: "stdout", "Retrieve context.");
160
161        info!(target: "stdout", "qdrant_url: {}, qdrant_collection_name: {}, limit: {}, score_threshold: {}", vdb_server_url.as_ref(), vdb_collection_name.as_ref(), limit, score_threshold.unwrap_or_default());
162    }
163
164    let running_mode = running_mode()?;
165    if running_mode != RunningMode::RAG {
166        let err_msg = format!("The context retrieval is not supported in the {running_mode} mode.");
167
168        #[cfg(feature = "logging")]
169        error!(target: "stdout", "{}", &err_msg);
170
171        return Err(LlamaCoreError::Operation(err_msg));
172    }
173
174    // create a Qdrant client
175    let mut qdrant_client = qdrant::Qdrant::new_with_url(vdb_server_url.as_ref().to_string());
176
177    // set the API key if provided
178    if let Some(key) = vdb_api_key.as_deref() {
179        if !key.is_empty() {
180            #[cfg(feature = "logging")]
181            debug!(target: "stdout", "Set the API key for the VectorDB server.");
182
183            qdrant_client.set_api_key(key);
184        }
185    }
186
187    // search for similar points
188    let scored_points = match qdrant_search_similar_points(
189        &qdrant_client,
190        vdb_collection_name.as_ref(),
191        query_embedding,
192        limit,
193        score_threshold,
194    )
195    .await
196    {
197        Ok(points) => points,
198        Err(e) => {
199            #[cfg(feature = "logging")]
200            error!(target: "stdout", "{e}");
201
202            return Err(e);
203        }
204    };
205
206    #[cfg(feature = "logging")]
207    info!(target: "stdout", "remove duplicates from {} scored points", scored_points.len());
208
209    // remove duplicates, which have the same source
210    let mut seen = HashSet::new();
211    let unique_scored_points: Vec<ScoredPoint> = scored_points
212        .into_iter()
213        .filter(|point| {
214            seen.insert(
215                point
216                    .payload
217                    .as_ref()
218                    .unwrap()
219                    .get("source")
220                    .unwrap()
221                    .to_string(),
222            )
223        })
224        .collect();
225
226    #[cfg(feature = "logging")]
227    info!(target: "stdout", "number of unique scored points: {}", unique_scored_points.len());
228
229    let ro = match unique_scored_points.is_empty() {
230        true => RetrieveObject {
231            points: None,
232            limit,
233            score_threshold: score_threshold.unwrap_or(0.0),
234        },
235        false => {
236            let mut points: Vec<RagScoredPoint> = vec![];
237            for point in unique_scored_points.iter() {
238                if let Some(payload) = &point.payload {
239                    if let Some(source) = payload.get("source").and_then(Value::as_str) {
240                        points.push(RagScoredPoint {
241                            source: source.to_string(),
242                            score: point.score as f64,
243                            from: DataFrom::VectorSearch,
244                        })
245                    }
246
247                    // For debugging purpose, log the optional search field if it exists
248                    #[cfg(feature = "logging")]
249                    if let Some(search) = payload.get("search").and_then(Value::as_str) {
250                        info!(target: "stdout", "search: {search}");
251                    }
252                }
253            }
254
255            RetrieveObject {
256                points: Some(points),
257                limit,
258                score_threshold: score_threshold.unwrap_or(0.0),
259            }
260        }
261    };
262
263    Ok(ro)
264}
265
266async fn qdrant_create_collection(
267    qdrant_client: &qdrant::Qdrant,
268    collection_name: impl AsRef<str>,
269    dim: usize,
270) -> Result<(), LlamaCoreError> {
271    #[cfg(feature = "logging")]
272    info!(target: "stdout", "Create a Qdrant collection named {} of {} dimensions.", collection_name.as_ref(), dim);
273
274    if let Err(e) = qdrant_client
275        .create_collection(collection_name.as_ref(), dim as u32)
276        .await
277    {
278        let err_msg = e.to_string();
279
280        #[cfg(feature = "logging")]
281        error!(target: "stdout", "{}", &err_msg);
282
283        return Err(LlamaCoreError::Qdrant(err_msg));
284    }
285
286    Ok(())
287}
288
289async fn qdrant_persist_embeddings(
290    qdrant_client: &qdrant::Qdrant,
291    collection_name: impl AsRef<str>,
292    embeddings: &[EmbeddingObject],
293    chunks: &[String],
294) -> Result<(), LlamaCoreError> {
295    #[cfg(feature = "logging")]
296    info!(target: "stdout", "Persist embeddings to the Qdrant instance.");
297
298    let mut points = Vec::<Point>::new();
299    for embedding in embeddings {
300        // convert the embedding to a vector
301        let vector: Vec<_> = embedding.embedding.iter().map(|x| *x as f32).collect();
302
303        // create a payload
304        let payload = serde_json::json!({"source": chunks[embedding.index as usize]})
305            .as_object()
306            .map(|m| m.to_owned());
307
308        // create a point
309        let p = Point {
310            id: PointId::Num(embedding.index),
311            vector,
312            payload,
313        };
314
315        points.push(p);
316    }
317
318    #[cfg(feature = "logging")]
319    info!(target: "stdout", "Number of points to be upserted: {}", points.len());
320
321    if let Err(e) = qdrant_client
322        .upsert_points(collection_name.as_ref(), points)
323        .await
324    {
325        let err_msg = format!("{e}");
326
327        #[cfg(feature = "logging")]
328        error!(target: "stdout", "{}", &err_msg);
329
330        return Err(LlamaCoreError::Qdrant(err_msg));
331    }
332
333    Ok(())
334}
335
336async fn qdrant_search_similar_points(
337    qdrant_client: &qdrant::Qdrant,
338    collection_name: impl AsRef<str>,
339    query_vector: &[f32],
340    limit: usize,
341    score_threshold: Option<f32>,
342) -> Result<Vec<ScoredPoint>, LlamaCoreError> {
343    #[cfg(feature = "logging")]
344    info!(target: "stdout", "Search similar points from the qdrant instance.");
345
346    match qdrant_client
347        .search_points(
348            collection_name.as_ref(),
349            query_vector.to_vec(),
350            limit as u64,
351            score_threshold,
352        )
353        .await
354    {
355        Ok(search_result) => {
356            #[cfg(feature = "logging")]
357            info!(target: "stdout", "Number of similar points found: {}", search_result.len());
358
359            Ok(search_result)
360        }
361        Err(e) => {
362            let err_msg = e.to_string();
363
364            #[cfg(feature = "logging")]
365            error!(target: "stdout", "{}", &err_msg);
366
367            Err(LlamaCoreError::Qdrant(err_msg))
368        }
369    }
370}