Adaptive RAG


This approach refines a naive RAG solution by first classifying the incoming query into a category (factual, analytical, opinion, or contextual) and then routing it to a specialized retriever.

The classification is done by the LLM itself: you simply ask it. In fact, all of the specialized retrievers are, in essence, different ways of rephrasing the question before retrieval.

The specialized retrievers are:

- Factual: the LLM enhances the query, documents are retrieved and then re-ranked by relevance.
- Analytical: the query is decomposed into sub-questions, each retrieved separately, and a diverse subset of the results is selected.
- Opinion: distinct viewpoints on the topic are identified and retrieval is run per viewpoint.
- Contextual: the query is reformulated with optional user context before retrieval and ranking.

The code runs locally against Ollama with the nomic-embed-text and qwen2.5:14b models. The text being queried is "A History of Mathematics" by Florian Cajori, available on Project Gutenberg.
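
The helpers get_llm() and get_embedding() come from helper_functions (imported below). As a minimal sketch, assuming the langchain_ollama package and a local Ollama instance, they might look like this (not the original helpers):

from langchain_ollama import ChatOllama, OllamaEmbeddings

def get_llm():
    # Local chat model served by Ollama (assumed configuration)
    return ChatOllama(model="qwen2.5:14b", temperature=0)

def get_embedding():
    # Local embedding model served by Ollama
    return OllamaEmbeddings(model="nomic-embed-text")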

Import and setup

import os
import sys

from termcolor import colored, cprint
from langchain.text_splitter import CharacterTextSplitter
from langchain.prompts import PromptTemplate

from langchain_core.retrievers import BaseRetriever
from typing import List
from langchain.docstore.document import Document
from langchain_core.vectorstores import InMemoryVectorStore
from pydantic import BaseModel, Field


sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # make the shared helpers importable
from helper_functions import *
from evaluation.evalute_rag import *
 

Classifying the query

You simply ask the LLM what kind of query it is:

class categories_options(BaseModel):
    category: str = Field(description="The category of the query, the options are: Factual, Analytical, Opinion, or Contextual", example="Factual")


class QueryClassifier:
    def __init__(self):
        self.llm = get_llm()
        self.prompt = PromptTemplate(
            input_variables=["query"],
            template="Classify the following query into one of these categories: Factual, Analytical, Opinion, or Contextual. The default is Contextual.\nQuery: {query}\nCategory:"
        )
        self.chain = self.prompt | self.llm.with_structured_output(categories_options)


    def classify(self, query):
        return self.chain.invoke(query).category
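
Used on its own, the classifier returns one of the four category names. The questions and the expected outputs below are illustrative:

classifier = QueryClassifier()
print(classifier.classify("When did Euclid write the Elements?"))              # e.g. Factual
print(classifier.classify("Was Newton or Leibniz the greater mathematician?")) # e.g. Opinion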

Base retriever

All variations use the same basic retriever, which is a naive RAG retriever.

class BaseRetrievalStrategy:
    def __init__(self, db):         
        self.db = db
        self.llm = get_llm()

    def retrieve(self, query, k=4):
        return self.db.similarity_search(query, k=k)

Factual retriever

The factual strategy first asks the LLM to enhance the query, retrieves twice as many candidates as needed, has the LLM score each one for relevance, and keeps the top k.

class relevant_score(BaseModel):
    score: float = Field(description="The relevance score of the document to the query", example=8.0)

class FactualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4):
        print()
        print(colored("============= Factual Retrieval Strategy =============", "green"))
        # Use LLM to enhance the query
        enhanced_query_prompt = PromptTemplate(
            input_variables=["query"],
            template="Enhance this factual query for better information retrieval: {query}"
        )
        query_chain = enhanced_query_prompt | self.llm
        enhanced_query = query_chain.invoke(query).content
        print(f'enhanced query: {enhanced_query}')

        # Retrieve documents using the enhanced query
        docs = self.db.similarity_search(enhanced_query, k=k*2)

        # Use LLM to rank the relevance of retrieved documents
        ranking_prompt = PromptTemplate(
            input_variables=["query", "doc"],
            template="On a scale of 1-10, how relevant is this document to the query: '{query}'?\nDocument: {doc}\nRelevance score:"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)

        ranked_docs = []
        print("ranking docs")
        for doc in docs:
            input_data = {"query": enhanced_query, "doc": doc.page_content}
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))

        # Sort by relevance score and return top k
        ranked_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]

Analytical retriever

Analytical means, in this context, that the question is decomposed into sub-questions; each sub-question is retrieved against separately, and the LLM then selects a diverse, relevant subset of the combined results.


class SelectedIndices(BaseModel):
    indices: List[int] = Field(description="Indices of selected documents", example=[0, 1, 2, 3])

class SubQueries(BaseModel):
    sub_queries: List[str] = Field(description="List of sub-queries for comprehensive analysis", example=["What is the population of New York?", "What is the GDP of New York?"])

class AnalyticalRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4):
        print()
        print(colored("============= Analytical Retrieval Strategy =============", "green"))
        # Use LLM to generate sub-queries for comprehensive analysis
        sub_queries_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Generate {k} sub-questions for: {query}"
        )

        sub_queries_chain = sub_queries_prompt | self.llm.with_structured_output(SubQueries)

        input_data = {"query": query, "k": k}
        sub_queries = sub_queries_chain.invoke(input_data).sub_queries
        print(f'sub queries for comprehensive analysis: {sub_queries}')

        all_docs = []
        for sub_query in sub_queries:
            all_docs.extend(self.db.similarity_search(sub_query, k=2))

        # Use LLM to ensure diversity and relevance
        diversity_prompt = PromptTemplate(
            input_variables=["query", "docs", "k"],
            template="""Select the most diverse and relevant set of {k} documents for the query: '{query}'\nDocuments: {docs}\n
            Return only the indices of selected documents as a list of integers."""
        )
        diversity_chain = diversity_prompt | self.llm.with_structured_output(SelectedIndices)
        docs_text = "\n".join([f"{i}: {doc.page_content[:50]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices_result = diversity_chain.invoke(input_data).indices
        print(f'selected diverse and relevant documents')
        
        return [all_docs[i] for i in selected_indices_result if i < len(all_docs)]

Opinion retriever

The opinion strategy asks the LLM for distinct viewpoints on the topic, retrieves documents for each viewpoint, and then has the LLM pick the most representative and diverse ones.

class OpinionRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=3):
        print()
        print(colored("============= Opinion Retrieval Strategy =============", "green"))        
        # Use LLM to identify potential viewpoints
        viewpoints_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Identify {k} distinct viewpoints or perspectives on the topic: {query}"
        )
        viewpoints_chain = viewpoints_prompt | self.llm
        input_data = {"query": query, "k": k}
        viewpoints = viewpoints_chain.invoke(input_data).content.split('\n')
        print(f'viewpoints: {viewpoints}')

        all_docs = []
        for viewpoint in viewpoints:
            all_docs.extend(self.db.similarity_search(f"{query} {viewpoint}", k=2))

        # Use LLM to classify and select diverse opinions
        opinion_prompt = PromptTemplate(
            input_variables=["query", "docs", "k"],
            template="Classify these documents into distinct opinions on '{query}' and select the {k} most representative and diverse viewpoints:\nDocuments: {docs}\nSelected indices:"
        )
        opinion_chain = opinion_prompt | self.llm.with_structured_output(SelectedIndices)
        
        docs_text = "\n".join([f"{i}: {doc.page_content[:100]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices = opinion_chain.invoke(input_data).indices
        print(f'selected diverse and relevant documents')
        
        return [all_docs[i] for i in selected_indices if isinstance(i, int) and i < len(all_docs)]

Contextual retriever

Contextual means that the query is reformulated using whatever user context is available (or a default note if none is given); retrieval and the subsequent relevance ranking both take that context into account.

class ContextualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4, user_context=None):
        print()
        print(colored("============= Contextual Retrieval Strategy =============", "green"))
        
        # Use LLM to incorporate user context into the query
        context_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Given the user context: {context}\nReformulate the query to best address the user's needs: {query}"
        )
        context_chain = context_prompt | self.llm
        input_data = {"query": query, "context": user_context or "No specific context provided"}
        contextualized_query = context_chain.invoke(input_data).content
        print(f'contextualized query: {contextualized_query}')

        # Retrieve documents using the contextualized query
        docs = self.db.similarity_search(contextualized_query, k=k*2)

        # Use LLM to rank the relevance of retrieved documents considering the user context
        ranking_prompt = PromptTemplate(
            input_variables=["query", "context", "doc"],
            template="Given the query: '{query}' and user context: '{context}', rate the relevance of this document on a scale of 1-10:\nDocument: {doc}\nRelevance score:"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)
        print("ranking docs")

        ranked_docs = []
        for doc in docs:
            input_data = {"query": contextualized_query, "context": user_context or "No specific context provided", "doc": doc.page_content}
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))


        # Sort by relevance score and return top k
        ranked_docs.sort(key=lambda x: x[1], reverse=True)

        return [doc for doc, _ in ranked_docs[:k]]

Adaptive retriever

class AdaptiveRetriever:
    def __init__(self, db):
        self.classifier = QueryClassifier()
        self.strategies = {
            "Factual": FactualRetrievalStrategy(db),
            "Analytical": AnalyticalRetrievalStrategy(db),
            "Opinion": OpinionRetrievalStrategy(db),
            "Contextual": ContextualRetrievalStrategy(db)
        }

    def get_relevant_documents(self, query: str) -> List[Document]:
        category = self.classifier.classify(query)
        strategy = self.strategies[category]
        return strategy.retrieve(query)
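
Used directly, the router classifies the query and delegates to the matching strategy. The query below is illustrative; store is the vector store loaded further down:

adaptive = AdaptiveRetriever(store)
docs = adaptive.get_relevant_documents("Who introduced the symbol for infinity?")
print(len(docs), docs[0].page_content[:80])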

Pydantic retriever

The adaptive retriever is wrapped in a LangChain BaseRetriever so it plugs into the rest of the ecosystem. Recent LangChain versions expect subclasses to implement _get_relevant_documents (and the async _aget_relevant_documents) rather than overriding the public get_relevant_documents, which is deprecated.

class PydanticAdaptiveRetriever(BaseRetriever):
    adaptive_retriever: AdaptiveRetriever = Field(exclude=True)

    class Config:
        arbitrary_types_allowed = True

    def _get_relevant_documents(self, query: str) -> List[Document]:
        return self.adaptive_retriever.get_relevant_documents(query)

    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        return self._get_relevant_documents(query)

All-in-one retriever


class AdaptiveRAG:
    def __init__(self, db):
        adaptive_retriever = AdaptiveRetriever(db)
        self.retriever = PydanticAdaptiveRetriever(adaptive_retriever=adaptive_retriever)
        self.llm = get_llm()

        # Create a custom prompt
        prompt_template = """Use the following pieces of context to answer the question at the end.
        If you don't know the answer, just say that you don't know, don't try to make up an answer.

        {context}

        Question: {question}
        Answer:"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

        # Create the LLM chain
        self.llm_chain = prompt | self.llm

    def answer(self, query: str):
        print()
        cprint(f"Query: {query}", "white", "on_blue")  # cprint prints directly, no need to wrap it in print()
        docs = self.retriever.invoke(query)
        input_data = {"context": "\n".join([doc.page_content for doc in docs]), "question": query}
        return self.llm_chain.invoke(input_data)

Testing it out
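
The vector store is loaded from disk. As a sketch, such a store could be built from the Gutenberg plain-text file roughly like this (file name and chunking parameters are assumptions, not the original preprocessing):

with open("../data/history_of_mathematics.txt") as f:
    raw_text = f.read()

splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = splitter.split_text(raw_text)

store = InMemoryVectorStore.from_texts(chunks, embedding=get_embedding())
store.dump("../data/HistoryOfMathematics.store")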

store = InMemoryVectorStore.load("../data/HistoryOfMathematics.store", embedding=get_embedding())
rag_system = AdaptiveRAG(store)
factual_result = rag_system.answer("The process of Antiphon and Bryson gave rise to what?").content
print(f"Answer: {factual_result}")

analytical_result = rag_system.answer("The Romans employed three different kinds of arithmetical calculations, which ones?").content
print(f"Answer: {analytical_result}")

opinion_result = rag_system.answer("Who is the most genius mathematician according to you?").content
print(f"Answer: {opinion_result}")

contextual_result = rag_system.answer("Who worked on the Inverse Problems of Tangents?").content
print(f"Answer: {contextual_result}")

Sample output for the last query, which the classifier routed to the factual strategy:

Query: Who worked on the Inverse Problems of Tangents?

============= Factual Retrieval Strategy =============

enhanced query: To improve the precision and effectiveness of your search, you might want to rephrase your query as follows:

"Who were the key mathematicians or researchers involved in the development of inverse problems related to tangents?"

This revision specifies that you are interested in identifying specific individuals who have contributed to this area of mathematics. It also clarifies that you are looking for work on "inverse problems," which is a more precise term than just "Inverse Problems of Tangents." If you can provide more context or specify the time period, it would further refine your search results.

ranking docs

Answer: Leibniz worked on the Inverse Problems of Tangents, as mentioned in the context provided.