author     Aditya <bluenerd@protonmail.com>   2025-02-10 22:20:49 +0530
committer  Aditya <bluenerd@protonmail.com>   2025-02-10 22:20:49 +0530
commit     93ee9c739c9dbe6ce281f544d428df807d476964 (patch)
tree       a56f69d0c4e38f467655bae5e1e991953c1010f2
initial commit
-rw-r--r--   .gitignore                          182
-rw-r--r--   Makefile                             34
-rw-r--r--   app.py                              250
-rw-r--r--   requirements/requirements-dev.txt     1
-rw-r--r--   requirements/requirements.txt         7

5 files changed, 474 insertions(+), 0 deletions(-)
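The app.py added below talks to a local Ollama server and expects two models to be present: nomic-embed-text:latest for embeddings and granite3-dense:latest for generation. A minimal preflight sketch using the ollama client pinned in requirements.txt; the ensure_models helper is hypothetical, and the field names in the ollama.list() response can differ across client versions:

```python
import ollama

# Models referenced in app.py below; adjust if you swap models.
REQUIRED_MODELS = ["nomic-embed-text:latest", "granite3-dense:latest"]


def ensure_models() -> None:
    """Pull any required model the local Ollama server is missing."""
    installed = {m.get("name") for m in ollama.list().get("models", [])}
    for model in REQUIRED_MODELS:
        if model not in installed:
            print(f"Pulling {model} ...")
            ollama.pull(model)  # Blocks until the download completes


if __name__ == "__main__":
    ensure_models()
```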
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ddd0f29
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,182 @@
+# Created by https://www.toptal.com/developers/gitignore/api/python
+# Edit at https://www.toptal.com/developers/gitignore?templates=python
+
+
+### Python ###
+# CUSTOM
+*.sqlite3
+*.bin
+demo-rag-chroma/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+# End of https://www.toptal.com/developers/gitignore/api/python
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..fd3f40b
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,34 @@
+SHELL := /bin/bash
+
+.PHONY: clean check setup fix run help
+.DEFAULT_GOAL=help
+VENV_DIR = .venv
+PYTHON_VERSION = python3.11
+
+check: # Ruff check
+	@ruff check .
+	@echo "✅ Check complete!"
+
+fix: # Fix auto-fixable linting issues
+	@ruff check app.py --fix
+
+clean: # Clean temporary files
+	@rm -rf __pycache__ .pytest_cache
+	@find . -name '*.pyc' -exec rm -r {} +
+	@find . -name '__pycache__' -exec rm -r {} +
+	@rm -rf build dist
+	@find . -name '*.egg-info' -type d -exec rm -r {} +
+
+run: # Run the application
+	@streamlit run app.py
+
+setup: # Initial project setup
+	@echo "Creating virtual env at: $(VENV_DIR)"
+	@$(PYTHON_VERSION) -m venv $(VENV_DIR)
+	@echo "Installing dependencies..."
+	@source $(VENV_DIR)/bin/activate && pip install -r requirements/requirements-dev.txt && pip install -r requirements/requirements.txt
+	@echo -e "\n✅ Done.\n🚀 Run the following commands to get started:\n\n ➡️ source $(VENV_DIR)/bin/activate\n ➡️ make run\n"
+
+
+help: # Show this help
+	@egrep -h '\s#\s' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?# "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
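The app.py added next splits each PDF into roughly 400-character chunks with 100 characters of overlap, preferring paragraph and sentence boundaries. As a quick illustration of what those settings do in isolation, the same splitter can be exercised on inline text (the sample string is made up):

```python
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Same settings as process_document() in app.py below.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=400,
    chunk_overlap=100,
    separators=["\n\n", "\n", ".", "?", "!", " ", ""],
)

sample = (
    "Retrieval-augmented generation pairs a vector store with a language model. " * 10
)
for i, chunk in enumerate(splitter.split_text(sample)):
    print(i, len(chunk), chunk[:60])
```

The overlap keeps a sentence that straddles a chunk boundary visible in both neighboring chunks, which helps retrieval at the cost of some duplicated text.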
+ + Returns: + chromadb.Collection: A ChromaDB collection configured with the Ollama embedding + function and cosine similarity space. + """ + ollama_ef = OllamaEmbeddingFunction( + url="http://localhost:11434/api/embeddings", + model_name="nomic-embed-text:latest", + ) + + chroma_client = chromadb.PersistentClient(path="./chromadb") + return chroma_client.get_or_create_collection( + name="rag_app", + embedding_function=ollama_ef, + metadata={"hnsw:space": "cosine"}, + ) + + +def add_to_vector_collection(all_splits: list[Document], file_name: str): + """Adds document splits to a vector collection for semantic search. + + Takes a list of document splits and adds them to a ChromaDB vector collection + along with their metadata and unique IDs based on the filename. + + Args: + all_splits: List of Document objects containing text chunks and metadata + file_name: String identifier used to generate unique IDs for the chunks + + Returns: + None. Displays a success message via Streamlit when complete. + + Raises: + ChromaDBError: If there are issues upserting documents to the collection + """ + collection = get_vector_collection() + documents, metadatas, ids = [], [], [] + + for idx, split in enumerate(all_splits): + documents.append(split.page_content) + metadatas.append(split.metadata) + ids.append(f"{file_name}_{idx}") + + collection.upsert( + documents=documents, + metadatas=metadatas, + ids=ids, + ) + st.success("Data added to the vector store!") + + +def query_collection(prompt: str, n_results: int = 10): + """Queries the vector collection with a given prompt to retrieve relevant documents. + + Args: + prompt: The search query text to find relevant documents. + n_results: Maximum number of results to return. Defaults to 10. + + Returns: + dict: Query results containing documents, distances and metadata from the collection. + + Raises: + ChromaDBError: If there are issues querying the collection. + """ + collection = get_vector_collection() + results = collection.query(query_texts=[prompt], n_results=n_results) + return results + + +def call_llm(context: str, prompt: str): + """Calls the language model with context and prompt to generate a response. + + Uses Ollama to stream responses from a language model by providing context and a + question prompt. The model uses a system prompt to format and ground its responses appropriately. + + Args: + context: String containing the relevant context for answering the question + prompt: String containing the user's question + + Yields: + String chunks of the generated response as they become available from the model + + Raises: + OllamaError: If there are issues communicating with the Ollama API + """ + response = ollama.chat( + model="granite3-dense:latest", + stream=True, + messages=[ + { + "role": "system", + "content": system_prompt, + }, + { + "role": "user", + "content": f"Context: {context}, Question: {prompt}", + }, + ], + ) + for chunk in response: + if chunk["done"] is False: + yield chunk["message"]["content"] + else: + break + + +def re_rank_cross_encoders(documents: list[str]) -> tuple[str, list[int]]: + """Re-ranks documents using a cross-encoder model for more accurate relevance scoring. + + Uses the MS MARCO MiniLM cross-encoder model to re-rank the input documents based on + their relevance to the query prompt. Returns the concatenated text of the top 3 most + relevant documents along with their indices. + + Args: + documents: List of document strings to be re-ranked. 
+ + Returns: + tuple: A tuple containing: + - relevant_text (str): Concatenated text from the top 3 ranked documents + - relevant_text_ids (list[int]): List of indices for the top ranked documents + + Raises: + ValueError: If documents list is empty + RuntimeError: If cross-encoder model fails to load or rank documents + """ + relevant_text = "" + relevant_text_ids = [] + + encoder_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2") + ranks = encoder_model.rank(prompt, documents, top_k=3) + for rank in ranks: + relevant_text += documents[rank["corpus_id"]] + relevant_text_ids.append(rank["corpus_id"]) + + return relevant_text, relevant_text_ids + + +if __name__ == "__main__": + # Document Upload Area + with st.sidebar: + st.set_page_config(page_title="RAG Question Answer") + uploaded_file = st.file_uploader( + "**š Upload PDF files for QnA**", type=["pdf"], accept_multiple_files=False + ) + + process = st.button( + "ā”ļø Process", + ) + if uploaded_file and process: + normalize_uploaded_file_name = uploaded_file.name.translate( + str.maketrans({"-": "_", ".": "_", " ": "_"}) + ) + all_splits = process_document(uploaded_file) + add_to_vector_collection(all_splits, normalize_uploaded_file_name) + + # Question and Answer Area + st.header("š£ļø RAG Question Answer") + prompt = st.text_area("**Ask a question related to your document:**") + ask = st.button( + "š„ Ask", + ) + + if ask and prompt: + results = query_collection(prompt) + context = results.get("documents")[0] + relevant_text, relevant_text_ids = re_rank_cross_encoders(context) + response = call_llm(context=relevant_text, prompt=prompt) + st.write_stream(response) + + with st.expander("See retrieved documents"): + st.write(results) + + with st.expander("See most relevant document ids"): + st.write(relevant_text_ids) + st.write(relevant_text) diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt new file mode 100644 index 0000000..6821d39 --- /dev/null +++ b/requirements/requirements-dev.txt @@ -0,0 +1 @@ +ruff==0.7.4 diff --git a/requirements/requirements.txt b/requirements/requirements.txt new file mode 100644 index 0000000..24e2c7b --- /dev/null +++ b/requirements/requirements.txt @@ -0,0 +1,7 @@ +ollama==0.3.3 # Local inference +chromadb==0.5.20 # Vector Database +sentence-transformers==3.3.1 # CrossEncoder Re-ranking +streamlit==1.40.1 # Application UI +PyMuPDF==1.24.14 # PDF Document loader +langchain-community==0.3.7 # Utils for text splitting + |