-rw-r--r--  Ollama/__init__.py       0
-rw-r--r--  Ollama/client.py       224
-rw-r--r--  README.md              284
-rw-r--r--  app.py                  63
-rw-r--r--  extract.py             198
-rw-r--r--  gradio-app.py          239
-rw-r--r--  helpers/__init__         0
-rw-r--r--  helpers/df_helpers.py   71
-rw-r--r--  helpers/prompts.py      74
9 files changed, 1008 insertions, 145 deletions
diff --git a/Ollama/__init__.py b/Ollama/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/Ollama/__init__.py
diff --git a/Ollama/client.py b/Ollama/client.py
new file mode 100644
index 0000000..9f4e336
--- /dev/null
+++ b/Ollama/client.py
@@ -0,0 +1,224 @@
+import os
+import json
+import requests
+
+BASE_URL = os.environ.get('OLLAMA_HOST', 'http://localhost:11434')
+
+# Generate a response for a given prompt with a provided model. This is a streaming endpoint, so the result
+# will be a series of response objects. The final response object includes statistics and additional data from
+# the request. Use the callback function to override the default handler.
+def generate(model_name, prompt, system=None, template=None, context=None, options=None, callback=None):
+ try:
+ url = f"{BASE_URL}/api/generate"
+ payload = {
+ "model": model_name,
+ "prompt": prompt,
+ "system": system,
+ "template": template,
+ "context": context,
+ "options": options
+ }
+
+ # Remove keys with None values
+ payload = {k: v for k, v in payload.items() if v is not None}
+
+ with requests.post(url, json=payload, stream=True) as response:
+ response.raise_for_status()
+
+ # Creating a variable to hold the context history of the final chunk
+ final_context = None
+
+ # Variable to hold concatenated response strings if no callback is provided
+ full_response = ""
+
+ # Iterating over the response line by line and displaying the details
+ for line in response.iter_lines():
+ if line:
+ # Parsing each line (JSON chunk) and extracting the details
+ chunk = json.loads(line)
+
+ # If a callback function is provided, call it with the chunk
+ if callback:
+ callback(chunk)
+ else:
+ # If this is not the last chunk, add the "response" field value to full_response and print it
+ if not chunk.get("done"):
+ response_piece = chunk.get("response", "")
+ full_response += response_piece
+ print(response_piece, end="", flush=True)
+
+ # Check if it's the last chunk (done is true)
+ if chunk.get("done"):
+ final_context = chunk.get("context")
+
+ # Return the full response and the final context
+ return full_response, final_context
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+ return None, None
+
+# Create a model from a Modelfile. Use the callback function to override the default handler.
+def create(model_name, model_path, callback=None):
+ try:
+ url = f"{BASE_URL}/api/create"
+ payload = {"name": model_name, "path": model_path}
+
+ # Making a POST request with the stream parameter set to True to handle streaming responses
+ with requests.post(url, json=payload, stream=True) as response:
+ response.raise_for_status()
+
+ # Iterating over the response line by line and displaying the status
+ for line in response.iter_lines():
+ if line:
+ # Parsing each line (JSON chunk) and extracting the status
+ chunk = json.loads(line)
+
+ if callback:
+ callback(chunk)
+ else:
+ print(f"Status: {chunk.get('status')}")
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+
+# Pull a model from the model registry. Cancelled pulls resume from where they left off, and multiple
+# calls will share the same download progress. Use the callback function to override the default handler.
+def pull(model_name, insecure=False, callback=None):
+ try:
+ url = f"{BASE_URL}/api/pull"
+ payload = {
+ "name": model_name,
+ "insecure": insecure
+ }
+
+ # Making a POST request with the stream parameter set to True to handle streaming responses
+ with requests.post(url, json=payload, stream=True) as response:
+ response.raise_for_status()
+
+ # Iterating over the response line by line and displaying the details
+ for line in response.iter_lines():
+ if line:
+ # Parsing each line (JSON chunk) and extracting the details
+ chunk = json.loads(line)
+
+ # If a callback function is provided, call it with the chunk
+ if callback:
+ callback(chunk)
+ else:
+ # Print the status message directly to the console
+ print(chunk.get('status', ''), end='', flush=True)
+
+ # If there's layer data, you might also want to print that (adjust as necessary)
+ if 'digest' in chunk:
+ print(f" - Digest: {chunk['digest']}", end='', flush=True)
+ print(f" - Total: {chunk['total']}", end='', flush=True)
+ print(f" - Completed: {chunk['completed']}", end='\n', flush=True)
+ else:
+ print()
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+
+# Push a model to the model registry. Use the callback function to override the default handler.
+def push(model_name, insecure=False, callback=None):
+ try:
+ url = f"{BASE_URL}/api/push"
+ payload = {
+ "name": model_name,
+ "insecure": insecure
+ }
+
+ # Making a POST request with the stream parameter set to True to handle streaming responses
+ with requests.post(url, json=payload, stream=True) as response:
+ response.raise_for_status()
+
+ # Iterating over the response line by line and displaying the details
+ for line in response.iter_lines():
+ if line:
+ # Parsing each line (JSON chunk) and extracting the details
+ chunk = json.loads(line)
+
+ # If a callback function is provided, call it with the chunk
+ if callback:
+ callback(chunk)
+ else:
+ # Print the status message directly to the console
+ print(chunk.get('status', ''), end='', flush=True)
+
+ # If there's layer data, you might also want to print that (adjust as necessary)
+ if 'digest' in chunk:
+ print(f" - Digest: {chunk['digest']}", end='', flush=True)
+ print(f" - Total: {chunk['total']}", end='', flush=True)
+ print(f" - Completed: {chunk['completed']}", end='\n', flush=True)
+ else:
+ print()
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+
+# List models that are available locally.
+def list():
+ try:
+ response = requests.get(f"{BASE_URL}/api/tags")
+ response.raise_for_status()
+ data = response.json()
+ models = data.get('models', [])
+ return models
+
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+ return None
+
+# Copy a model. Creates a model with another name from an existing model.
+def copy(source, destination):
+ try:
+ # Create the JSON payload
+ payload = {
+ "source": source,
+ "destination": destination
+ }
+
+ response = requests.post(f"{BASE_URL}/api/copy", json=payload)
+ response.raise_for_status()
+
+ # If the request was successful, return a message indicating that the copy was successful
+ return "Copy successful"
+
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+ return None
+
+# Delete a model and its data.
+def delete(model_name):
+ try:
+ url = f"{BASE_URL}/api/delete"
+ payload = {"name": model_name}
+ response = requests.delete(url, json=payload)
+ response.raise_for_status()
+ return "Delete successful"
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+ return None
+
+# Show info about a model.
+def show(model_name):
+ try:
+ url = f"{BASE_URL}/api/show"
+ payload = {"name": model_name}
+ response = requests.post(url, json=payload)
+ response.raise_for_status()
+
+ # Parse the JSON response and return it
+ data = response.json()
+ return data
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+ return None
+
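+# Check whether the Ollama server is reachable.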
+def heartbeat():
+ try:
+ url = f"{BASE_URL}/"
+ response = requests.head(url)
+ response.raise_for_status()
+ return "Ollama is running"
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+ return "Ollama is not running"
+
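+# --- Usage sketch (illustrative only) ---
+# A minimal demonstration of this module, assuming an Ollama server is reachable
+# at OLLAMA_HOST and that a model such as "phi3:latest" has already been pulled
+# (the model name and prompts below are placeholders, not part of the API).
+if __name__ == "__main__":
+    print(heartbeat())
+
+    # List locally available models returned by /api/tags.
+    for m in list() or []:
+        print(m.get("name"))
+
+    # Stream a generation with the default handler, which prints chunks as they arrive.
+    full_response, context = generate("phi3:latest", "Say hello in one sentence.")
+    print(f"\n[{len(full_response or '')} characters streamed]")
+
+    # Override the default handler with a callback that collects raw chunks instead.
+    chunks = []
+    generate("phi3:latest", "Now say goodbye.", context=context,
+             callback=lambda chunk: chunks.append(chunk.get("response", "")))
+    print("".join(chunks))
+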
diff --git a/README.md b/README.md
index abaa239..f58c303 100644
--- a/README.md
+++ b/README.md
@@ -1,141 +1,151 @@
-# **Document Ingestion and Semantic Query System Using Retrieval-Augmented Generation (RAG)**
-
-## **Overview**
-This application implements a **Retrieval-Augmented Generation (RAG) based Question Answering System** using Streamlit for the user interface, ChromaDB for vector storage, and Ollama for generating responses. The system allows users to upload **PDF documents**, process them into **text chunks**, store them as **vector embeddings**, and retrieve relevant information to generate AI-powered responses.
-
----
-
-## **System Components**
-
-### **1. File Processing and Text Chunking**
-**Function:** `process_document(uploaded_file: UploadedFile) -> list[Document]`
-
-- Takes a user-uploaded **PDF file** and processes it into **smaller text chunks**.
-- Uses **PyMuPDFLoader** to extract text from PDFs.
-- Splits extracted text into **overlapping segments** using **RecursiveCharacterTextSplitter**.
-- Returns a list of **Document objects** containing text chunks and metadata.
-
-**Key Steps:**
-1. Save uploaded file to a **temporary file**.
-2. Load content using **PyMuPDFLoader**.
-3. Split text using **RecursiveCharacterTextSplitter**.
-4. Delete the temporary file.
-5. Return the **list of Document objects**.
-
----
-
-### **2. Vector Storage and Retrieval (ChromaDB)**
-
-#### **Creating a ChromaDB Collection**
-**Function:** `get_vector_collection() -> chromadb.Collection`
-
-- Initializes **ChromaDB** with a **persistent vector store**.
-- Uses **OllamaEmbeddingFunction** to generate vector embeddings.
-- Retrieves or creates a collection for storing **document embeddings**.
-- Uses **cosine similarity** for querying documents.
-
-**Key Steps:**
-1. Define **OllamaEmbeddingFunction** for embedding generation.
-2. Initialize **ChromaDB PersistentClient**.
-3. Retrieve or create a **ChromaDB collection** for storing vectors.
-4. Return the **collection object**.
-
-#### **Adding Documents to Vector Store**
-**Function:** `add_to_vector_collection(all_splits: list[Document], file_name: str)`
-
-- Takes a list of document chunks and stores them in **ChromaDB**.
-- Each document is stored with **unique IDs** based on file name.
-- Success message displayed via **Streamlit**.
-
-**Key Steps:**
-1. Retrieve ChromaDB collection using `get_vector_collection()`.
-2. Convert document chunks into a list of **text embeddings, metadata, and unique IDs**.
-3. Use `upsert()` to store document embeddings.
-4. Display success message.
-
-#### **Querying the Vector Collection**
-**Function:** `query_collection(prompt: str, n_results: int = 10) -> dict`
-
-- Queries **ChromaDB** with a user-provided search query.
-- Returns the **top n most relevant documents** based on similarity.
-
-**Key Steps:**
-1. Retrieve ChromaDB collection.
-2. Perform query using `collection.query()`.
-3. Return **retrieved documents and metadata**.
-
----
-
-### **3. Language Model Interaction (Ollama API)**
-
-#### **Generating Responses using the AI Model**
-**Function:** `call_llm(context: str, prompt: str)`
-
-- Calls **Ollama**'s language model to generate a **context-aware response**.
-- Uses a **system prompt** to guide the model’s behavior.
-- Streams the AI-generated response in **chunks**.
-
-**Key Steps:**
-1. Send **system prompt** and user query to **Ollama**.
-2. Retrieve and yield streamed responses.
-3. Display results in **Streamlit**.
+# extract.py
+
+## Overview
+This program processes text documents, extracts key concepts using a language model, constructs a graph representation of these concepts, and visualizes the resulting network using Pyvis and NetworkX. The extracted relationships between terms are stored in CSV files, and the final graph is displayed in an interactive HTML file.
+
+## Dependencies
+The program requires the following Python libraries:
+- `pyvis.network`
+- `seaborn`
+- `networkx`
+- `pandas`
+- `numpy`
+- `langchain.document_loaders`
+- `langchain.text_splitter`
+- `helpers.df_helpers` (local helper module)
+
+It also uses the standard-library modules `os`, `pathlib`, `random`, `sys`, and `subprocess`, and calls `wsl-notify-send.exe` (WSL-specific) for desktop notifications.
+
+## Workflow
+
+### 1. Input Handling
+The program expects a single command-line argument containing the text to analyze. The text is written to `data_input/model_response/response.txt`, and the matching output directory is created under `data_output/` if it does not exist.
+
+### 2. Document Loading and Splitting
+- The program loads documents using `langchain.document_loaders.DirectoryLoader`.
+- Text is split into chunks using `RecursiveCharacterTextSplitter` with a chunk size of 1500 and overlap of 150 characters.
+- The extracted text chunks are converted into a Pandas DataFrame, as sketched below.
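+
+A minimal sketch of this step (paths and parameters mirror `extract.py`; it assumes the input text has already been written under `data_input/model_response/`):
+
+```python
+from langchain.document_loaders import DirectoryLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from helpers.df_helpers import documents2Dataframe
+
+# Load everything in the input directory and split it into overlapping chunks.
+documents = DirectoryLoader("./data_input/model_response", show_progress=True).load()
+splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
+pages = splitter.split_documents(documents)
+
+# One row per chunk: the text, its source metadata, and a generated chunk_id.
+df = documents2Dataframe(pages)
+print(df.shape)
+```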
+
+### 3. Graph Generation
+- If `regenerate` is set to `True`, extracted text chunks are processed to generate a concept graph using `df2Graph`.
+- The relationships are stored in a CSV file (`graph.csv`).
+- The extracted text chunks are stored in `chunks.csv`.
+
+### 4. Contextual Proximity Calculation
+The `contextual_proximity` function:
+- Establishes relationships between terms appearing in the same text chunk.
+- Generates additional edges in the graph based on co-occurrence in chunks.
+- Drops edges with only one occurrence.
+- Assigns the label `contextual proximity` to these relationships (see the sketch below).
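+
+A simplified sketch of the co-occurrence join on toy data (node names are invented; column names follow `graph.csv`):
+
+```python
+import pandas as pd
+
+# Toy graph dataframe: three relations extracted from two chunks.
+toy = pd.DataFrame({
+    "node_1":   ["ollama", "chromadb", "ollama"],
+    "node_2":   ["llm", "embeddings", "rag"],
+    "edge":     ["serves", "stores", "powers"],
+    "chunk_id": ["c1", "c1", "c2"],
+})
+
+# Melt both node columns into one, then self-join on chunk_id so every pair
+# of terms sharing a chunk becomes a candidate edge.
+long = pd.melt(toy, id_vars=["chunk_id"], value_vars=["node_1", "node_2"],
+               value_name="node").drop(columns=["variable"])
+wide = pd.merge(long, long, on="chunk_id", suffixes=("_1", "_2"))
+wide = wide[wide["node_1"] != wide["node_2"]]
+
+# Count co-occurrences per pair; extract.py additionally drops pairs seen only once.
+pairs = (wide.groupby(["node_1", "node_2"])
+             .agg(count=("chunk_id", "count"))
+             .reset_index())
+pairs["edge"] = "contextual proximity"
+print(pairs)
+```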
+
+### 5. Graph Construction
+- A `networkx.Graph` object is created.
+- Nodes and edges are added, with edge weights normalized by dividing by 4.
+- Communities in the graph are detected using the Girvan-Newman algorithm.
+- Each community is assigned a unique color (see the sketch below).
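+
+A minimal sketch of the community-detection step, using a built-in toy graph in place of the concept graph:
+
+```python
+import networkx as nx
+import seaborn as sns
+
+G = nx.karate_club_graph()  # stand-in for the concept graph built above
+
+# girvan_newman yields successively finer partitions; extract.py advances the
+# generator twice and keeps the second-level communities.
+communities_generator = nx.community.girvan_newman(G)
+next(communities_generator)
+communities = sorted(map(sorted, next(communities_generator)))
+
+# One color per community, as in colors2Community(); node size tracks degree.
+palette = sns.color_palette("hls", len(communities)).as_hex()
+for color, community in zip(palette, communities):
+    for node in community:
+        G.nodes[node]["color"] = color
+        G.nodes[node]["size"] = G.degree[node]
+```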
+
+### 6. Graph Visualization
+- Pyvis is used to create an interactive visualization of the graph.
+- The visualization is saved as `index.html` inside the `docs` directory.
+- The layout uses Pyvis's `force_atlas_2based` physics model to position the nodes.
+
+## Output
+- Processed document data (`chunks.csv`).
+- Extracted concept relationships (`graph.csv`).
+- Interactive graph visualization (`index.html`).
+- Notifications are sent via `wsl-notify-send.exe` when processing starts and completes.
+
+## Usage
+Run the script with the text to analyze as its first argument (the argument is written verbatim to `data_input/model_response/response.txt` before processing):
+```bash
+python extract.py "Text to build a concept graph from..."
+```
+
+## Notes
+- The program creates necessary directories if they do not exist.
+- If `regenerate` is `False`, the program reads precomputed relationships from `graph.csv` instead of generating them anew.
+- Community detection enhances graph visualization by grouping related terms.
+- The visualization can be viewed in a web browser by opening `docs/index.html`.
---
-### **4. Cross-Encoder Based Re-Ranking**
-**Function:** `re_rank_cross_encoders(documents: list[str]) -> tuple[str, list[int]]`
-
-- Uses **CrossEncoder (MS MARCO MiniLM model)** to **re-rank retrieved documents**.
-- Selects the **top 3 most relevant documents**.
-- Returns **concatenated relevant text** and **document indices**.
-
-**Key Steps:**
-1. Load **MS MARCO MiniLM CrossEncoder model**.
-2. Rank documents using **cross-encoder re-ranking**.
-3. Extract the **top-ranked documents**.
-4. Return **concatenated text** and **indices**.
-
----
-
-## **User Interface (Streamlit)**
-
-### **1. Document Uploading and Processing**
-- Sidebar allows **PDF file upload**.
-- User clicks **Process** to extract text and store embeddings.
-- File name is **normalized** before processing.
-- Extracted **text chunks** are stored in **ChromaDB**.
-
-### **2. Question Answering System**
-- Main interface displays a **text area** for users to enter questions.
-- Clicking **Ask** triggers the retrieval and response generation process:
- 1. **Query ChromaDB** to retrieve relevant documents.
- 2. **Re-rank documents** using **cross-encoder**.
- 3. **Pass relevant text** and **question** to the **LLM**.
- 4. Stream and display the AI-generated response.
- 5. Provide options to view **retrieved documents and rankings**.
-
----
-
-## **Technologies Used**
-- **Streamlit** → UI framework for interactive user interface.
-- **PyMuPDF** → PDF text extraction.
-- **ChromaDB** → Vector database for semantic search.
-- **Ollama** → LLM API for generating responses.
-- **LangChain** → Document processing utilities.
-- **Sentence Transformers (CrossEncoder)** → Document re-ranking.
-
----
-
-## **Error Handling & Edge Cases**
-- **File I/O Errors**: Proper handling of **temporary file read/write issues**.
-- **ChromaDB Errors**: Ensures **database consistency and query failures** are managed.
-- **Ollama API Failures**: Detects and **handles API unavailability or timeouts**.
-- **Empty Document Handling**: Ensures that **no empty files** are processed.
-- **Invalid Queries**: Provides **feedback for low-relevance queries**.
-
----
-
-## **Conclusion**
-This application provides a **RAG-based interactive Q&A system**, leveraging **retrieval, ranking, and generation** techniques to deliver highly **relevant AI-generated responses**. The architecture ensures efficient document processing, vector storage, and intelligent answer generation using state-of-the-art models and embeddings.
-
+# gradio-app.py
+
+## Overview
+This program implements a Retrieval-Augmented Generation (RAG) system that allows users to upload PDF documents, extract and store textual information in a vector database, and query the system to retrieve contextually relevant information. It also integrates a knowledge graph generation mechanism to visualize extracted knowledge.
+
+## Dependencies
+The program utilizes the following libraries:
+- `gradio`: For building an interactive web-based interface.
+- `chromadb`: For vector storage and retrieval.
+- `ollama`: For handling LLM-based responses.
+- `langchain_community`: For PDF document loading and text processing.
+- `sentence_transformers`: For cross-encoder-based document re-ranking.
+- `subprocess`, `tempfile`, and `os`: For handling system-level tasks.
+
+## Workflow
+1. **Document Processing**
+ - A PDF file is uploaded via the Gradio interface.
+ - The `process_document` function extracts text from the PDF and splits it into chunks using `RecursiveCharacterTextSplitter`.
+ - The extracted text chunks are stored in a ChromaDB vector collection.
+
+2. **Query Processing**
+ - A user enters a query via the Gradio interface.
+ - The `query_collection` function retrieves relevant text chunks from the vector collection.
+ - The retrieved chunks are re-ranked using a cross-encoder model.
+ - The most relevant text is passed to an LLM for generating a response.
+
+3. **Knowledge Graph Generation**
+ - The generated response is saved temporarily.
+ - The `extract.py` script is executed to create a knowledge graph.
+ - The system notifies the user of success or failure.
+
+## Core Functions
+### `process_document(file_path: str) -> list[Document]`
+Extracts text from a PDF and splits it into chunks for further processing.
+
+### `get_vector_collection() -> chromadb.Collection`
+Creates or retrieves a ChromaDB collection for vector-based semantic search.
+
+### `add_to_vector_collection(all_splits: list[Document], file_name: str)`
+Adds processed document chunks to the vector collection with metadata.
+
+### `query_collection(prompt: str, n_results: int = 10) -> dict`
+Queries the vector collection to retrieve contextually relevant documents.
+
+### `call_llm(context, prompt) -> str`
+Passes the context and question to the `deepseek-r1` LLM for response generation.
+
+### `re_rank_cross_encoders(documents: list[str], prompt) -> tuple[str, list[int]]`
+Uses a cross-encoder model to re-rank retrieved documents based on query relevance.
+
+### `process_question(prompt: str) -> str`
+Combines querying, re-ranking, and LLM response generation.
+
+### `create_knowledge_graph(response: str) -> str`
+Executes the `extract.py` script to generate a knowledge graph based on LLM output.
+
+### `process_pdf(file_path: str)`
+Processes an uploaded PDF, extracts text, and adds it to the vector collection.
+
+## Gradio Interface
+The Gradio-based UI consists of:
+- **File Upload Section**: Users upload a PDF for processing.
+- **Query Section**: Users ask questions related to the uploaded content.
+- **Knowledge Graph Section**: Users can generate a visual representation of extracted knowledge.
+
+## Execution
+To run the program:
+```sh
+python gradio-app.py
+```
+This launches the Gradio interface, allowing document uploads and question answering.
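+
+The same pipeline can also be driven without the UI; a minimal sketch using the functions defined in this file (the PDF name and question are illustrative, and Ollama must be running with the `nomic-embed-text` and `deepseek-r1` models available):
+
+```python
+process_pdf("paper.pdf")                        # chunk, embed, and store in ChromaDB
+answer = process_question("What is the paper about?")
+print(answer)
+print(create_knowledge_graph(answer))           # runs extract.py on the answer
+```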
+
+## Error Handling
+- If document processing fails, an error message is displayed in the UI.
+- If no relevant documents are found for a query, the system returns an appropriate message.
+- If knowledge graph generation fails, the error is captured and displayed.
diff --git a/app.py b/app.py
index be25e93..aaaaeb5 100644
--- a/app.py
+++ b/app.py
@@ -1,9 +1,12 @@
import os
import tempfile
+import subprocess
import chromadb
import ollama
import streamlit as st
+import Ollama
+
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
@@ -159,7 +162,7 @@ def call_llm(context: str, prompt: str):
OllamaError: If there are issues communicating with the Ollama API
"""
response = ollama.chat(
- model="granite3-dense:latest",
+ model="phi3:latest",
stream=True,
messages=[
{
@@ -172,11 +175,19 @@ def call_llm(context: str, prompt: str):
},
],
)
- for chunk in response:
- if chunk["done"] is False:
- yield chunk["message"]["content"]
- else:
- break
+
+ if "full_response" not in st.session_state:
+ st.session_state["full_response"] = ''
+
+ def response_generator():
+ for chunk in response:
+ if chunk["done"] is False:
+ text_chunk = chunk["message"]["content"]
+ st.session_state["full_response"] += text_chunk
+ yield text_chunk
+ else:
+ break
+ return response_generator()
def re_rank_cross_encoders(documents: list[str]) -> tuple[str, list[int]]:
@@ -236,11 +247,47 @@ if __name__ == "__main__":
)
if ask and prompt:
+ #process = subprocess.run(['wsl-notify-send.exe', 'inside streamlit'])
results = query_collection(prompt)
context = results.get("documents")[0]
relevant_text, relevant_text_ids = re_rank_cross_encoders(context)
- response = call_llm(context=relevant_text, prompt=prompt)
- st.write_stream(response)
+ st.session_state["full_response"] = ""
+
+ response_stream = call_llm(context=relevant_text, prompt=prompt)
+ st.write_stream(response_stream)
+
+ button = st.button('click me')
+ if button:
+ st.write('clicked')
+ if "full_response" in st.session_state and st.session_state["full_response"]:
+ create_graph = st.button("Create Knowledge Graph")
+
+
+ if create_graph:
+ st.write("✅ Button Clicked! Creating temporary file...")
+
+ with tempfile.NamedTemporaryFile(delete=False, mode="w", encoding="utf-8", suffix=".txt") as temp_file:
+ temp_file.write(st.session_state["full_response"])
+ temp_file_path = temp_file.name # Get the temporary file path
+
+ try:
+ process = subprocess.run(
+ ["python", "extract.py", temp_file_path], # Pass the file instead of long text
+ check=True,
+ text=True,
+ capture_output=True
+ )
+
+ st.write("📜 Output:", process.stdout)
+ st.write("⚠️ Errors:", process.stderr)
+
+ if process.returncode == 0:
+ st.success("✅ Knowledge Graph Created Successfully!")
+ else:
+ st.error(f"❌ extract.py failed with return code {process.returncode}")
+
+ except subprocess.CalledProcessError as e:
+ st.error(f"❌ Error while running extract.py: {e}")
with st.expander("See retrieved documents"):
st.write(results)
diff --git a/extract.py b/extract.py
new file mode 100644
index 0000000..f3523ae
--- /dev/null
+++ b/extract.py
@@ -0,0 +1,198 @@
+from pyvis.network import Network
+import seaborn as sns
+import networkx as nx
+from helpers.df_helpers import graph2Df
+from helpers.df_helpers import df2Graph
+from helpers.df_helpers import documents2Dataframe
+import pandas as pd
+import numpy as np
+import os
+from langchain.document_loaders import PyPDFLoader, UnstructuredPDFLoader, PyPDFium2Loader
+from langchain.document_loaders import PyPDFDirectoryLoader, DirectoryLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from pathlib import Path
+import random
+import sys
+import subprocess
+
+process = subprocess.run(['wsl-notify-send.exe', 'inside extract.py'])
+args = sys.argv[1:]
+
+# Input data directory
+data_dir = "model_response"
+inputdirectory = Path(f"./data_input/{data_dir}")
+if not os.path.exists(inputdirectory):
+ os.makedirs(inputdirectory)
+# This is where the output csv files will be written
+out_dir = data_dir
+outputdirectory = Path(f"./data_output/{out_dir}")
+if not os.path.exists(outputdirectory):
+ os.makedirs(outputdirectory)
+
+with open(f'{inputdirectory}/response.txt', 'w') as file:
+ file.write(args[0])
+
+# Dir PDF Loader
+# loader = PyPDFDirectoryLoader(inputdirectory)
+# File Loader
+# loader = PyPDFLoader("./data/MedicalDocuments/orf-path_health-n1.pdf")
+loader = DirectoryLoader(inputdirectory, show_progress=True)
+documents = loader.load()
+
+splitter = RecursiveCharacterTextSplitter(
+ chunk_size=1500,
+ chunk_overlap=150,
+ length_function=len,
+ is_separator_regex=False,
+)
+
+pages = splitter.split_documents(documents)
+print("Number of chunks = ", len(pages))
+print(pages[0].page_content)
+
+df = documents2Dataframe(pages)
+print(df.shape)
+df.head()
+
+# df2Graph (in helpers/df_helpers.py) uses the prompt in helpers/prompts.py to extract concepts from each text chunk.
+
+# To regenerate the graph with LLM, set this to True
+regenerate = True
+
+if regenerate:
+ concepts_list = df2Graph(df, model='phi3:latest')
+ dfg1 = graph2Df(concepts_list)
+ if not os.path.exists(outputdirectory):
+ os.makedirs(outputdirectory)
+
+ dfg1.to_csv(outputdirectory/"graph.csv", sep="|", index=False)
+ df.to_csv(outputdirectory/"chunks.csv", sep="|", index=False)
+else:
+ dfg1 = pd.read_csv(outputdirectory/"graph.csv", sep="|")
+
+dfg1.replace("", np.nan, inplace=True)
+dfg1.dropna(subset=["node_1", "node_2", 'edge'], inplace=True)
+dfg1['count'] = 4
+# Give every LLM-extracted relation a weight of 4; the contextual-proximity
+# relations calculated later each contribute a weight of 1.
+print(dfg1.shape)
+dfg1.head()
+
+
+def contextual_proximity(df: pd.DataFrame) -> pd.DataFrame:
+ # Melt the dataframe into a list of nodes
+ dfg_long = pd.melt(
+ df, id_vars=["chunk_id"], value_vars=["node_1", "node_2"], value_name="node"
+ )
+ dfg_long.drop(columns=["variable"], inplace=True)
+ # Self join with chunk id as the key will create a link between terms occurring in the same text chunk.
+ dfg_wide = pd.merge(dfg_long, dfg_long, on="chunk_id",
+ suffixes=("_1", "_2"))
+ # drop self loops
+ self_loops_drop = dfg_wide[dfg_wide["node_1"] == dfg_wide["node_2"]].index
+ dfg2 = dfg_wide.drop(index=self_loops_drop).reset_index(drop=True)
+ # Group and count edges.
+ dfg2 = (
+ dfg2.groupby(["node_1", "node_2"])
+ .agg({"chunk_id": [",".join, "count"]})
+ .reset_index()
+ )
+ dfg2.columns = ["node_1", "node_2", "chunk_id", "count"]
+ dfg2.replace("", np.nan, inplace=True)
+ dfg2.dropna(subset=["node_1", "node_2"], inplace=True)
+ # Drop edges with 1 count
+ dfg2 = dfg2[dfg2["count"] != 1]
+ dfg2["edge"] = "contextual proximity"
+ return dfg2
+
+
+dfg2 = contextual_proximity(dfg1)
+dfg2.tail()
+
+dfg = pd.concat([dfg1, dfg2], axis=0)
+dfg = (
+ dfg.groupby(["node_1", "node_2"])
+ .agg({"chunk_id": ",".join, "edge": ','.join, 'count': 'sum'})
+ .reset_index()
+)
+dfg
+
+nodes = pd.concat([dfg['node_1'], dfg['node_2']], axis=0).unique()
+nodes.shape
+
+G = nx.Graph()
+
+# Add nodes to the graph
+for node in nodes:
+ G.add_node(
+ str(node)
+ )
+
+# Add edges to the graph
+for index, row in dfg.iterrows():
+ G.add_edge(
+ str(row["node_1"]),
+ str(row["node_2"]),
+ title=row["edge"],
+ weight=row['count']/4
+ )
+
+communities_generator = nx.community.girvan_newman(G)
+top_level_communities = next(communities_generator)
+next_level_communities = next(communities_generator)
+communities = sorted(map(sorted, next_level_communities))
+print("Number of Communities = ", len(communities))
+print(communities)
+
+palette = "hls"
+
+# Now add these colors to communities and make another dataframe
+
+
+def colors2Community(communities) -> pd.DataFrame:
+ # Define a color palette
+ p = sns.color_palette(palette, len(communities)).as_hex()
+ random.shuffle(p)
+ rows = []
+ group = 0
+ for community in communities:
+ color = p.pop()
+ group += 1
+ for node in community:
+ rows += [{"node": node, "color": color, "group": group}]
+ df_colors = pd.DataFrame(rows)
+ return df_colors
+
+
+colors = colors2Community(communities)
+colors
+
+for index, row in colors.iterrows():
+ G.nodes[row['node']]['group'] = row['group']
+ G.nodes[row['node']]['color'] = row['color']
+ G.nodes[row['node']]['size'] = G.degree[row['node']]
+
+
+graph_output_directory = "./docs/index.html"
+if not os.path.exists('./docs'):
+ os.makedirs('./docs')
+net = Network(
+ notebook=False,
+ # bgcolor="#1a1a1a",
+ cdn_resources="remote",
+ height="900px",
+ width="100%",
+ select_menu=True,
+ # font_color="#cccccc",
+ filter_menu=False,
+)
+
+net.from_nx(G)
+# net.repulsion(node_distance=150, spring_length=400)
+net.force_atlas_2based(central_gravity=0.015, gravity=-31)
+# net.barnes_hut(gravity=-18100, central_gravity=5.05, spring_length=380)
+net.show_buttons(filter_=["physics"])
+
+net.show(graph_output_directory, notebook=False)
+
+process = subprocess.run(['wsl-notify-send.exe', 'graph generated'])
diff --git a/gradio-app.py b/gradio-app.py
new file mode 100644
index 0000000..52ab111
--- /dev/null
+++ b/gradio-app.py
@@ -0,0 +1,239 @@
+import os
+import tempfile
+import subprocess
+
+import gradio as gr
+import chromadb
+import ollama
+from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
+from langchain_community.document_loaders import PyMuPDFLoader
+from langchain_core.documents import Document
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from sentence_transformers import CrossEncoder
+
+system_prompt = """
+You are an AI assistant tasked with providing detailed answers based solely on the given context. Your goal is to analyze the information provided and formulate a comprehensive, well-structured response to the question.
+
+context will be passed as "Context:"
+user question will be passed as "Question:"
+
+To answer the question:
+1. Thoroughly analyze the context, identifying key information relevant to the question.
+2. Organize your thoughts and plan your response to ensure a logical flow of information.
+3. Formulate a detailed answer that directly addresses the question, using only the information provided in the context.
+4. Ensure your answer is comprehensive, covering all relevant aspects found in the context.
+5. If the context doesn't contain sufficient information to fully answer the question, state this clearly in your response.
+
+Format your response as follows:
+1. Use clear, concise language.
+2. Organize your answer into paragraphs for readability.
+3. Use bullet points or numbered lists where appropriate to break down complex information.
+4. If relevant, include any headings or subheadings to structure your response.
+5. Ensure proper grammar, punctuation, and spelling throughout your answer.
+
+Important: Base your entire response solely on the information provided in the context. Do not include any external knowledge or assumptions not present in the given text.
+"""
+
+
+def process_document(file_path: str) -> list[Document]:
+ """Processes an uploaded PDF file by converting it to text chunks.
+
+ Takes an uploaded PDF file, saves it temporarily, loads and splits the content
+ into text chunks using recursive character splitting.
+
+ Args:
+ uploaded_file: A Streamlit UploadedFile object containing the PDF file
+
+ Returns:
+ A list of Document objects containing the chunked text from the PDF
+
+ Raises:
+ IOError: If there are issues reading/writing the temporary file
+ """
+
+ loader = PyMuPDFLoader(file_path)
+ docs = loader.load()
+
+ text_splitter = RecursiveCharacterTextSplitter(
+ chunk_size=400,
+ chunk_overlap=100,
+ separators=["\n\n", "\n", ".", "?", "!", " ", ""],
+ )
+ return text_splitter.split_documents(docs)
+
+
+def get_vector_collection() -> chromadb.Collection:
+ """Gets or creates a ChromaDB collection for vector storage.
+
+ Creates an Ollama embedding function using the nomic-embed-text model and initializes
+ a persistent ChromaDB client. Returns a collection that can be used to store and
+ query document embeddings.
+
+ Returns:
+ chromadb.Collection: A ChromaDB collection configured with the Ollama embedding
+ function and cosine similarity space.
+ """
+ ollama_ef = OllamaEmbeddingFunction(
+ url="http://localhost:11434/api/embeddings",
+ model_name="nomic-embed-text:latest",
+ )
+
+ chroma_client = chromadb.PersistentClient(path="./chromadb")
+ return chroma_client.get_or_create_collection(
+ name="rag_app",
+ embedding_function=ollama_ef,
+ metadata={"hnsw:space": "cosine"},
+ )
+
+
+def add_to_vector_collection(all_splits: list[Document], file_name: str):
+ """Adds document splits to a vector collection for semantic search.
+
+ Takes a list of document splits and adds them to a ChromaDB vector collection
+ along with their metadata and unique IDs based on the filename.
+
+ Args:
+ all_splits: List of Document objects containing text chunks and metadata
+ file_name: String identifier used to generate unique IDs for the chunks
+
+ Returns:
+ None.
+
+ Raises:
+ ChromaDBError: If there are issues upserting documents to the collection
+ """
+ collection = get_vector_collection()
+ documents, metadatas, ids = [], [], []
+
+ for idx, split in enumerate(all_splits):
+ documents.append(split.page_content)
+ metadatas.append(split.metadata)
+ ids.append(f"{file_name}_{idx}")
+
+ collection.upsert(
+ documents=documents,
+ metadatas=metadatas,
+ ids=ids,
+ )
+
+
+def query_collection(prompt: str, n_results: int = 10):
+ """Queries the vector collection with a given prompt to retrieve relevant documents.
+
+ Args:
+ prompt: The search query text to find relevant documents.
+ n_results: Maximum number of results to return. Defaults to 10.
+
+ Returns:
+ dict: Query results containing documents, distances and metadata from the collection.
+
+ Raises:
+ ChromaDBError: If there are issues querying the collection.
+ """
+ collection = get_vector_collection()
+ results = collection.query(query_texts=[prompt], n_results=n_results)
+ return results
+
+
+def call_llm(context, prompt):
+ response = ollama.chat(
+ model="deepseek-r1:latest",
+ stream=False,
+ messages=[
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": f"Context: {context}, Question: {prompt}"},
+ ],
+ )
+ return response["message"]["content"]
+
+
+def re_rank_cross_encoders(documents: list[str], prompt) -> tuple[str, list[int]]:
+ """Re-ranks documents using a cross-encoder model for more accurate relevance scoring.
+
+ Uses the MS MARCO MiniLM cross-encoder model to re-rank the input documents based on
+ their relevance to the query prompt. Returns the concatenated text of the top 3 most
+ relevant documents along with their indices.
+
+ Args:
+ documents: List of document strings to be re-ranked.
+ prompt: The query string used to score document relevance.
+
+ Returns:
+ tuple: A tuple containing:
+ - relevant_text (str): Concatenated text from the top 3 ranked documents
+ - relevant_text_ids (list[int]): List of indices for the top ranked documents
+
+ Raises:
+ ValueError: If documents list is empty
+ RuntimeError: If cross-encoder model fails to load or rank documents
+ """
+ relevant_text = ""
+ relevant_text_ids = []
+
+ encoder_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
+ ranks = encoder_model.rank(prompt, documents, top_k=3)
+ for rank in ranks:
+ relevant_text += documents[rank["corpus_id"]]
+ relevant_text_ids.append(rank["corpus_id"])
+
+ return relevant_text, relevant_text_ids
+
+
+def process_question(prompt):
+ results = query_collection(prompt)
+ context = results.get("documents")[0]
+ relevant_text, _ = re_rank_cross_encoders(context, prompt)
+ response = call_llm(relevant_text, prompt)
+ return response
+
+
+def create_knowledge_graph(response):
+ # extract.py expects the text to analyze as its first argument, so the response
+ # text is passed directly. It is also written to a temporary file, although that
+ # file is not used by the subprocess call below.
+ with tempfile.NamedTemporaryFile(delete=False, mode="w", encoding="utf-8", suffix=".txt") as temp_file:
+ temp_file.write(response)
+ temp_file_path = temp_file.name
+
+ # check=True is intentionally omitted so that a non-zero exit code is reported
+ # through the returned status string instead of raising an exception.
+ process = subprocess.run(
+ ["python", "extract.py", response],
+ text=True, capture_output=True)
+
+ if process.returncode == 0:
+ return "✅ Knowledge Graph Created Successfully!"
+ return f"❌ Failed: {process.stderr}"
+
+
+def process_pdf(file_path):
+ all_splits = process_document(file_path)
+ file_name = os.path.basename(file_path).replace(".", "_").replace(" ", "_")
+ add_to_vector_collection(all_splits, file_name)
+ return f"✅ Processed and indexed {file_name}"
+
+
+def main():
+ with gr.Blocks() as demo:
+ gr.Markdown("## 🗣️ RAG Question Answer System")
+
+ with gr.Row():
+ upload = gr.File(label="Upload PDF")
+ process_button = gr.Button("Process")
+ status_output = gr.Textbox(label="Status")
+
+ process_button.click(process_pdf, inputs=upload, outputs=status_output)
+
+ with gr.Row():
+ question = gr.Textbox(label="Ask a Question")
+ ask_button = gr.Button("Ask")
+ answer_output = gr.Textbox(label="Response")
+
+ ask_button.click(process_question, inputs=question,
+ outputs=answer_output)
+
+ with gr.Row():
+ create_graph_button = gr.Button("Create Knowledge Graph")
+ graph_output = gr.Textbox(label="Graph Status")
+
+ create_graph_button.click(
+ create_knowledge_graph, inputs=answer_output, outputs=graph_output)
+
+ demo.launch()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/helpers/__init__ b/helpers/__init__
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/helpers/__init__
diff --git a/helpers/df_helpers.py b/helpers/df_helpers.py
new file mode 100644
index 0000000..b241df5
--- /dev/null
+++ b/helpers/df_helpers.py
@@ -0,0 +1,71 @@
+import uuid
+import pandas as pd
+import numpy as np
+from .prompts import extractConcepts
+from .prompts import graphPrompt
+
+
+def documents2Dataframe(documents) -> pd.DataFrame:
+ rows = []
+ for chunk in documents:
+ row = {
+ "text": chunk.page_content,
+ **chunk.metadata,
+ "chunk_id": uuid.uuid4().hex,
+ }
+ rows = rows + [row]
+
+ df = pd.DataFrame(rows)
+ return df
+
+
+def df2ConceptsList(dataframe: pd.DataFrame) -> list:
+ # dataframe.reset_index(inplace=True)
+ results = dataframe.apply(
+ lambda row: extractConcepts(
+ row.text, {"chunk_id": row.chunk_id, "type": "concept"}
+ ),
+ axis=1,
+ )
+ # invalid json results in NaN
+ results = results.dropna()
+ results = results.reset_index(drop=True)
+
+ ## Flatten the list of lists to one single list of entities.
+ concept_list = np.concatenate(results).ravel().tolist()
+ return concept_list
+
+
+def concepts2Df(concepts_list) -> pd.DataFrame:
+ ## Remove all NaN entities
+ concepts_dataframe = pd.DataFrame(concepts_list).replace(" ", np.nan)
+ concepts_dataframe = concepts_dataframe.dropna(subset=["entity"])
+ concepts_dataframe["entity"] = concepts_dataframe["entity"].apply(
+ lambda x: x.lower()
+ )
+
+ return concepts_dataframe
+
+
+def df2Graph(dataframe: pd.DataFrame, model=None) -> list:
+ # dataframe.reset_index(inplace=True)
+ results = dataframe.apply(
+ lambda row: graphPrompt(row.text, {"chunk_id": row.chunk_id}, model), axis=1
+ )
+ # invalid json results in NaN
+ results = results.dropna()
+ results = results.reset_index(drop=True)
+
+ ## Flatten the list of lists to one single list of entities.
+ concept_list = np.concatenate(results).ravel().tolist()
+ return concept_list
+
+
+def graph2Df(nodes_list) -> pd.DataFrame:
+ ## Remove all NaN entities
+ graph_dataframe = pd.DataFrame(nodes_list).replace(" ", np.nan)
+ graph_dataframe = graph_dataframe.dropna(subset=["node_1", "node_2"])
+ graph_dataframe["node_1"] = graph_dataframe["node_1"].apply(lambda x: x.lower())
+ graph_dataframe["node_2"] = graph_dataframe["node_2"].apply(lambda x: x.lower())
+
+ return graph_dataframe
diff --git a/helpers/prompts.py b/helpers/prompts.py
new file mode 100644
index 0000000..1c3801e
--- /dev/null
+++ b/helpers/prompts.py
@@ -0,0 +1,74 @@
+import json
+import sys
+
+sys.path.append("..")  # make the top-level Ollama package importable when running from helpers/
+
+import Ollama.client as client
+from yachalk import chalk
+
+
+def extractConcepts(prompt: str, metadata={}, model="mistral-openorca:latest"):
+ SYS_PROMPT = (
+ "Your task is extract the key concepts (and non personal entities) mentioned in the given context. "
+ "Extract only the most important and atomistic concepts, if needed break the concepts down to the simpler concepts."
+ "Categorize the concepts in one of the following categories: "
+ "[event, concept, place, object, document, organisation, condition, misc]\n"
+ "Format your output as a list of json with the following format:\n"
+ "[\n"
+ " {\n"
+ ' "entity": The Concept,\n'
+ ' "importance": The concontextual importance of the concept on a scale of 1 to 5 (5 being the highest),\n'
+ ' "category": The Type of Concept,\n'
+ " }, \n"
+ "{ }, \n"
+ "]\n"
+ )
+ response, _ = client.generate(
+ model_name=model, system=SYS_PROMPT, prompt=prompt)
+ try:
+ result = json.loads(response)
+ result = [dict(item, **metadata) for item in result]
+ except Exception:
+ print("\n\nERROR ### Here is the buggy response: ", response, "\n\n")
+ result = None
+ return result
+
+
+def graphPrompt(input: str, metadata={}, model="mistral-openorca:latest"):
+ if model is None:
+ model = "mistral-openorca:latest"
+
+ # model_info = client.show(model_name=model)
+ # print( chalk.blue(model_info))
+
+ SYS_PROMPT = (
+ "You are a network graph maker who extracts terms and their relations from a given context. "
+ "You are provided with a context chunk (delimited by ```) Your task is to extract the ontology "
+ "of terms mentioned in the given context. These terms should represent the key concepts as per the context. \n"
+ "Thought 1: While traversing through each sentence, Think about the key terms mentioned in it.\n"
+ "\tTerms may include object, entity, location, organization, person, \n"
+ "\tcondition, acronym, documents, service, concept, etc.\n"
+ "\tTerms should be as atomistic as possible\n\n"
+ "Thought 2: Think about how these terms can have one on one relation with other terms.\n"
+ "\tTerms that are mentioned in the same sentence or the same paragraph are typically related to each other.\n"
+ "\tTerms can be related to many other terms\n\n"
+ "Thought 3: Find out the relation between each such related pair of terms. \n\n"
+ "Format your output as a list of json. Each element of the list contains a pair of terms"
+ "and the relation between them, like the follwing: \n"
+ "[\n"
+ " {\n"
+ ' "node_1": "A concept from extracted ontology",\n'
+ ' "node_2": "A related concept from extracted ontology",\n'
+ ' "edge": "relationship between the two concepts, node_1 and node_2 in one or two sentences"\n'
+ " }, {...}\n"
+ "]"
+ )
+
+ USER_PROMPT = f"context: ```{input}``` \n\n output: "
+ response, _ = client.generate(
+ model_name=model, system=SYS_PROMPT, prompt=USER_PROMPT)
+ try:
+ result = json.loads(response)
+ result = [dict(item, **metadata) for item in result]
+ except Exception:
+ print("\n\nERROR ### Here is the buggy response: ", response, "\n\n")
+ result = None
+ return result
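+
+# Illustrative shape of a successful graphPrompt() result (values invented):
+# [
+#   {"node_1": "ollama", "node_2": "language model",
+#    "edge": "ollama serves the language model locally", "chunk_id": "9c1f..."},
+#   ...
+# ]
+# On a JSON parse failure both functions print the raw response and return None.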