import os
import sys
from typing import Dict, Any, List
from termcolor import colored, cprint
from pydantic import BaseModel, Field
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import InMemoryVectorStore
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # make the helper modules importable
from helper_functions import *
from evaluation.evalute_rag import *
Adaptive RAG
This approach refines a naive RAG solution by first classifying the query into a category (Factual, Analytical, Opinion, or Contextual) and then routing it to a specialized retriever.
The classification is done by the LLM itself: you simply ask. In essence, each specialized retriever is just a different way of rephrasing the question before retrieval.
The specialized retrievers are:
- factual: the query is rephrased by asking the LLM "Enhance this factual query for better information retrieval: …"
- opinion: the query is expanded into multiple queries to cover several aspects of the question, via "Identify {k} distinct viewpoints or perspectives on the topic: …" (k defaults to 3)
- analytical: sub-queries are generated via "Generate {k} sub-questions for: …" (k defaults to 4)
- contextual: the query is reformulated against the supplied context via "Given the user context: {context}\nReformulate the query to best address the user's needs: {query}"
The code runs against a local Ollama instance, using the nomic-embed-text embedding model and the qwen2.5:14b chat model. The text being queried is "A History of Mathematics" by Florian Cajori, from Project Gutenberg.
Import and setup
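get_llm() and get_embedding() come from helper_functions and are not shown in the notebook. A minimal sketch of what they could look like for the local Ollama setup described above, assuming the langchain-ollama package is installed (everything here is an assumption, not the notebook's actual helper code):

from langchain_ollama import ChatOllama, OllamaEmbeddings

def get_llm():
    # Local chat model served by Ollama; temperature 0 keeps the
    # classification and ranking steps as deterministic as possible.
    return ChatOllama(model="qwen2.5:14b", temperature=0)

def get_embedding():
    # Embedding model used to build and query the vector store.
    return OllamaEmbeddings(model="nomic-embed-text")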
Classifying the query
You simply ask the LLM what kind of query it is:
class categories_options(BaseModel):
    category: str = Field(description="The category of the query, the options are: Factual, Analytical, Opinion, or Contextual", example="Factual")

class QueryClassifier:
    def __init__(self):
        self.llm = get_llm()
        self.prompt = PromptTemplate(
            input_variables=["query"],
            template="Classify the following query into one of these categories: Factual, Analytical, Opinion, or Contextual. The default is Contextual.\nQuery: {query}\nCategory:"
        )
        self.chain = self.prompt | self.llm.with_structured_output(categories_options)

    def classify(self, query):
        return self.chain.invoke(query).category
Base retriever
All strategies share the same base class, whose default retrieve() is a plain similarity search, i.e. naive RAG.
class BaseRetrievalStrategy:
    def __init__(self, db):
        self.db = db
        self.llm = get_llm()

    def retrieve(self, query, k=4):
        return self.db.similarity_search(query, k=k)
Factual retriever
class relevant_score(BaseModel):
    score: float = Field(description="The relevance score of the document to the query", example=8.0)

class FactualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4):
        print()
        print(colored("============= Factual Retrieval Strategy =============", "green"))
        # Use LLM to enhance the query
        enhanced_query_prompt = PromptTemplate(
            input_variables=["query"],
            template="Enhance this factual query for better information retrieval: {query}"
        )
        query_chain = enhanced_query_prompt | self.llm
        enhanced_query = query_chain.invoke(query).content
        print(f'enhanced query: {enhanced_query}')

        # Retrieve documents using the enhanced query
        docs = self.db.similarity_search(enhanced_query, k=k*2)

        # Use LLM to rank the relevance of retrieved documents
        ranking_prompt = PromptTemplate(
            input_variables=["query", "doc"],
            template="On a scale of 1-10, how relevant is this document to the query: '{query}'?\nDocument: {doc}\nRelevance score:"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)

        ranked_docs = []
        print("ranking docs")
        for doc in docs:
            input_data = {"query": enhanced_query, "doc": doc.page_content}
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))

        # Sort by relevance score and return top k
        ranked_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]
Analytical retriever
Analytical means, in this context, that the question is broken into sub-questions, each of which is used to retrieve documents; the LLM then selects a diverse, relevant subset of those documents.
class SelectedIndices(BaseModel):
    indices: List[int] = Field(description="Indices of selected documents", example=[0, 1, 2, 3])

class SubQueries(BaseModel):
    sub_queries: List[str] = Field(description="List of sub-queries for comprehensive analysis", example=["What is the population of New York?", "What is the GDP of New York?"])

class AnalyticalRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4):
        print()
        print(colored("============= Analytical Retrieval Strategy =============", "green"))
        # Use LLM to generate sub-queries for comprehensive analysis
        sub_queries_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Generate {k} sub-questions for: {query}"
        )
        sub_queries_chain = sub_queries_prompt | self.llm.with_structured_output(SubQueries)
        input_data = {"query": query, "k": k}
        sub_queries = sub_queries_chain.invoke(input_data).sub_queries
        print(f'sub queries for comprehensive analysis: {sub_queries}')

        all_docs = []
        for sub_query in sub_queries:
            all_docs.extend(self.db.similarity_search(sub_query, k=2))

        # Use LLM to ensure diversity and relevance
        diversity_prompt = PromptTemplate(
            input_variables=["query", "docs", "k"],
            template="""Select the most diverse and relevant set of {k} documents for the query: '{query}'\nDocuments: {docs}\n
Return only the indices of selected documents as a list of integers."""
        )
        diversity_chain = diversity_prompt | self.llm.with_structured_output(SelectedIndices)
        docs_text = "\n".join([f"{i}: {doc.page_content[:50]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices_result = diversity_chain.invoke(input_data).indices
        print('selected diverse and relevant documents')
        return [all_docs[i] for i in selected_indices_result if i < len(all_docs)]
Opinion retriever
class OpinionRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=3):
        print()
        print(colored("============= Opinion Retrieval Strategy =============", "green"))
        # Use LLM to identify potential viewpoints
        viewpoints_prompt = PromptTemplate(
            input_variables=["query", "k"],
            template="Identify {k} distinct viewpoints or perspectives on the topic: {query}"
        )
        viewpoints_chain = viewpoints_prompt | self.llm
        input_data = {"query": query, "k": k}
        viewpoints = viewpoints_chain.invoke(input_data).content.split('\n')
        print(f'viewpoints: {viewpoints}')

        all_docs = []
        for viewpoint in viewpoints:
            all_docs.extend(self.db.similarity_search(f"{query} {viewpoint}", k=2))

        # Use LLM to classify and select diverse opinions
        opinion_prompt = PromptTemplate(
            input_variables=["query", "docs", "k"],
            template="Classify these documents into distinct opinions on '{query}' and select the {k} most representative and diverse viewpoints:\nDocuments: {docs}\nSelected indices:"
        )
        opinion_chain = opinion_prompt | self.llm.with_structured_output(SelectedIndices)
        docs_text = "\n".join([f"{i}: {doc.page_content[:100]}..." for i, doc in enumerate(all_docs)])
        input_data = {"query": query, "docs": docs_text, "k": k}
        selected_indices = opinion_chain.invoke(input_data).indices
        print('selected diverse and relevant documents')
        return [all_docs[i] for i in selected_indices if isinstance(i, int) and i < len(all_docs)]
Contextual retriever
Contextual means that the query is reformulated based on user-supplied context (if any), and documents are then retrieved and re-ranked against that reformulated query.
class ContextualRetrievalStrategy(BaseRetrievalStrategy):
    def retrieve(self, query, k=4, user_context=None):
        print()
        print(colored("============= Contextual Retrieval Strategy =============", "green"))
        # Use LLM to incorporate user context into the query
        context_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Given the user context: {context}\nReformulate the query to best address the user's needs: {query}"
        )
        context_chain = context_prompt | self.llm
        input_data = {"query": query, "context": user_context or "No specific context provided"}
        contextualized_query = context_chain.invoke(input_data).content
        print(f'contextualized query: {contextualized_query}')

        # Retrieve documents using the contextualized query
        docs = self.db.similarity_search(contextualized_query, k=k*2)

        # Use LLM to rank the relevance of retrieved documents considering the user context
        ranking_prompt = PromptTemplate(
            input_variables=["query", "context", "doc"],
            template="Given the query: '{query}' and user context: '{context}', rate the relevance of this document on a scale of 1-10:\nDocument: {doc}\nRelevance score:"
        )
        ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)

        print("ranking docs")
        ranked_docs = []
        for doc in docs:
            input_data = {"query": contextualized_query, "context": user_context or "No specific context provided", "doc": doc.page_content}
            score = float(ranking_chain.invoke(input_data).score)
            ranked_docs.append((doc, score))

        # Sort by relevance score and return top k
        ranked_docs.sort(key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in ranked_docs[:k]]
Adaptive retriever
class AdaptiveRetriever:
    def __init__(self, db):
        self.classifier = QueryClassifier()
        self.strategies = {
            "Factual": FactualRetrievalStrategy(db),
            "Analytical": AnalyticalRetrievalStrategy(db),
            "Opinion": OpinionRetrievalStrategy(db),
            "Contextual": ContextualRetrievalStrategy(db)
        }

    def get_relevant_documents(self, query: str) -> List[Document]:
        category = self.classifier.classify(query)
        # Fall back to Contextual (the classifier prompt's stated default) if the
        # LLM returns an unexpected label.
        strategy = self.strategies.get(category, self.strategies["Contextual"])
        return strategy.retrieve(query)
Pydantic retriever
class PydanticAdaptiveRetriever(BaseRetriever):
    adaptive_retriever: AdaptiveRetriever = Field(exclude=True)

    class Config:
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        return self.adaptive_retriever.get_relevant_documents(query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        return self.get_relevant_documents(query)
Running this cell, LangChain emits DeprecationWarnings pointing out that retrievers should implement the abstract _get_relevant_documents / _aget_relevant_documents methods instead of overriding get_relevant_documents / aget_relevant_documents.
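A version that follows the warning's advice would implement the private hook instead. A sketch (not the notebook's original code; CallbackManagerForRetrieverRun lives in langchain_core.callbacks):

from langchain_core.callbacks import CallbackManagerForRetrieverRun

class PydanticAdaptiveRetriever(BaseRetriever):
    adaptive_retriever: AdaptiveRetriever = Field(exclude=True)

    class Config:
        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        # BaseRetriever.invoke() wraps this with callback handling, so only
        # the private method needs to be implemented.
        return self.adaptive_retriever.get_relevant_documents(query)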
All in one retriever
class AdaptiveRAG:
    def __init__(self, db: InMemoryVectorStore):
        # db is the vector store all retrieval strategies will search
        adaptive_retriever = AdaptiveRetriever(db)
        self.retriever = PydanticAdaptiveRetriever(adaptive_retriever=adaptive_retriever)
        self.llm = get_llm()

        # Create a custom prompt
        prompt_template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Answer:"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

        # Create the LLM chain
        self.llm_chain = prompt | self.llm

    def answer(self, query: str):
        print()
        cprint(f"Query:{query}", "white", "on_blue")
        docs = self.retriever.get_relevant_documents(query)
        input_data = {"context": "\n".join([doc.page_content for doc in docs]), "question": query}
        return self.llm_chain.invoke(input_data)
Testing it out
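The vector store was built once, offline, and persisted to disk; the notebook only loads it. For reference, a plausible one-off build step might look like this (the source-text path and chunk sizes are assumptions; the store path matches the load call below):

# One-off script to build and persist the vector store (hypothetical path and chunk sizes).
text = open("../data/history_of_mathematics.txt", encoding="utf-8").read()

splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = splitter.split_text(text)

store = InMemoryVectorStore.from_texts(chunks, embedding=get_embedding())
store.dump("../data/HistoryOfMathematics.store")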
store = InMemoryVectorStore.load("../data/HistoryOfMathematics.store", embedding=get_embedding())
rag_system = AdaptiveRAG(store)

factual_result = rag_system.answer("The process of Antiphon and Bryson gave rise to what?").content
print(f"Answer: {factual_result}")
analytical_result = rag_system.answer("The Romans employed three different kinds of arithmetical calculations, which ones?").content
print(f"Answer: {analytical_result}")
opinion_result = rag_system.answer("Who is the most genius mathematician according to you?").content
print(f"Answer: {opinion_result}")
contextual_result = rag_system.answer("Who worked on the Inverse Problems of Tangents?").content
print(f"Answer: {contextual_result}")Query:Who worked on the Inverse Problems of Tangents? None ============= Factual Retrieval Strategy ============= enhanced query: To improve the precision and effectiveness of your search, you might want to rephrase your query as follows: "Who were the key mathematicians or researchers involved in the development of inverse problems related to tangents?" This revision specifies that you are interested in identifying specific individuals who have contributed to this area of mathematics. It also clarifies that you are looking for work on "inverse problems," which is a more precise term than just "Inverse Problems of Tangents." If you can provide more context or specify the time period, it would further refine your search results. ranking docs Answer: Leibniz worked on the Inverse Problems of Tangents, as mentioned in the context provided.