about summary refs log tree commit diff
path: root/app.py
diff options
context:
space:
mode:
Diffstat (limited to 'app.py')
-rw-r--r-- app.py 63
1 file changed, 55 insertions(+), 8 deletions(-)
diff --git a/app.py b/app.py
index be25e93..aaaaeb5 100644
--- a/app.py
+++ b/app.py
@@ -1,9 +1,12 @@
import os
import tempfile
+import subprocess
import chromadb
import ollama
import streamlit as st
+import Ollama
+
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
@@ -159,7 +162,7 @@ def call_llm(context: str, prompt: str):
OllamaError: If there are issues communicating with the Ollama API
"""
response = ollama.chat(
- model="granite3-dense:latest",
+ model="phi3:latest",
stream=True,
messages=[
{
@@ -172,11 +175,19 @@ def call_llm(context: str, prompt: str):
},
],
)
- for chunk in response:
- if chunk["done"] is False:
- yield chunk["message"]["content"]
- else:
- break
+
+ if "full_response" not in st.session_state:
+ st.session_state["full_response"] = ''
+
+ def response_generator():
+ for chunk in response:
+ if chunk["done"] is False:
+ text_chunk = chunk["message"]["content"]
+ st.session_state["full_response"] += text_chunk
+ yield text_chunk
+ else:
+ break
+ return response_generator()
def re_rank_cross_encoders(documents: list[str]) -> tuple[str, list[int]]:
@@ -236,11 +247,47 @@ if __name__ == "__main__":
)
if ask and prompt:
+ #process = subprocess.run(['wsl-notify-send.exe', 'inside streamlit'])
results = query_collection(prompt)
context = results.get("documents")[0]
relevant_text, relevant_text_ids = re_rank_cross_encoders(context)
- response = call_llm(context=relevant_text, prompt=prompt)
- st.write_stream(response)
+ st.session_state["full_response"] = ""
+
+ response_stream = call_llm(context=relevant_text, prompt=prompt)
+ st.write_stream(response_stream)
+
+ button = st.button('click me')
+ if button:
+ st.write('clicked')
+ if "full_response" in st.session_state and st.session_state["full_response"]:
+ create_graph = st.button("Create Knowledge Graph")
+
+
+ if create_graph:
+ st.write("✅ Button Clicked! Creating temporary file...")
+
+ with tempfile.NamedTemporaryFile(delete=False, mode="w", encoding="utf-8", suffix=".txt") as temp_file:
+ temp_file.write(st.session_state["full_response"])
+ temp_file_path = temp_file.name # Get the temporary file path
+
+ try:
+ process = subprocess.run(
+ ["python", "extract.py", temp_file_path], # Pass the file instead of long text
+ check=True,
+ text=True,
+ capture_output=True
+ )
+
+ st.write("📜 Output:", process.stdout)
+ st.write("⚠️ Errors:", process.stderr)
+
+ if process.returncode == 0:
+ st.success("✅ Knowledge Graph Created Successfully!")
+ else:
+ st.error(f"❌ extract.py failed with return code {process.returncode}")
+
+ except subprocess.CalledProcessError as e:
+ st.error(f"❌ Error while running extract.py: {e}")
with st.expander("See retrieved documents"):
st.write(results)