@@ -1,22 +1,47 @@
 from langchain_community.llms import Ollama
 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
+from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.document_loaders import TextLoader
+from langchain_text_splitters import CharacterTextSplitter
 import os
 from dotenv import load_dotenv
+from tempfile import NamedTemporaryFile
 
 # Load environment variables
 load_dotenv()
 
-# Global dictionary to store chat history (in-memory for simplicity)
+# Global dictionary for chat history
 chat_history = {}
 
-# Initialize Ollama LLM with tinyllama
-llm = Ollama(
-    model="tinyllama",
-    temperature=0.7
-)
+# Initialize Ollama LLM and Embeddings
+llm = Ollama(model="tinyllama", temperature=0.7)
+embeddings = OllamaEmbeddings(model="tinyllama")
 
-# Define a prompt template that includes history
+# Initialize global Chroma vector store (in-memory)
+vector_store = Chroma.from_texts([""], embeddings)  # Initialize empty store
+
+# Function to index uploaded file
+def index_file(file_content: bytes, file_name: str):
+    with NamedTemporaryFile(delete=False, suffix=os.path.splitext(file_name)[1]) as temp_file:
+        temp_file.write(file_content)
+        temp_file_path = temp_file.name
+
+    loader = TextLoader(temp_file_path)
+    documents = loader.load()
+
+    # Split documents into chunks
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+    chunks = text_splitter.split_documents(documents)
+
+    # Add to vector store
+    vector_store.add_documents(chunks)
+
+    # Clean up temp file
+    os.unlink(temp_file_path)
+
+# Define prompt templates
 def get_prompt_with_history(session_id):
     history = chat_history.get(session_id, [])
     history_text = "\n".join([f"User: {msg['question']}\nAI: {msg['answer']}" for msg in history]) if history else "No previous conversation."
@@ -25,6 +50,15 @@ def get_prompt_with_history(session_id):
         template=f"Previous conversation:\n{history_text}\n\nResponda à seguinte pergunta: {{question}}"
     )
 
+def get_prompt_with_history_and_docs(session_id, docs):
+    history = chat_history.get(session_id, [])
+    history_text = "\n".join([f"User: {msg['question']}\nAI: {msg['answer']}" for msg in history]) if history else "No previous conversation."
+    docs_text = "\n".join([f"Source: {doc.page_content}" for doc in docs]) if docs else "No relevant documents found."
+    return PromptTemplate(
+        input_variables=["question"],
+        template=f"Previous conversation:\n{history_text}\n\nRelevant documents:\n{docs_text}\n\nResponda à seguinte pergunta usando as fontes relevantes e citando trechos como fontes: {{question}}"
+    )
+
 def get_answer(session_id: str, question: str) -> str:
     # Get or initialize chat history for this session
     if session_id not in chat_history:
@@ -43,6 +77,28 @@ def get_answer(session_id: str, question: str) -> str:
 
     return response
 
+# RAG function for /ask endpoint
+def ask_rag(session_id: str, question: str, file_content: bytes = None, file_name: str = None) -> dict:
+    if file_content and file_name:
+        index_file(file_content, file_name)
+
+    if session_id not in chat_history:
+        chat_history[session_id] = []
+
+    docs = vector_store.similarity_search(question, k=3)
+
+    prompt = get_prompt_with_history_and_docs(session_id, docs)
+    chain = LLMChain(llm=llm, prompt=prompt)
+
+    response = chain.run(question=question)
+    response = response[:100] if len(response) > 100 else response
+
+    chat_history[session_id].append({"question": question, "answer": response})
+
+    sources = [doc.page_content for doc in docs]
+
+    return {"answer": response, "sources": sources}
+
 if __name__ == "__main__":
     session_id = "test_session"
     print(get_answer(session_id, "Qual a capital da França?"))
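
For context, here is a minimal sketch of how `ask_rag` might be exposed over HTTP. The diff's comment only hints at an `/ask` endpoint; the use of FastAPI, the `main` module name, the route path, and the form field names below are assumptions for illustration, not part of this change.

```python
from typing import Optional

from fastapi import FastAPI, File, Form, UploadFile

# Assumption: the module shown in the diff above is importable as `main`.
from main import ask_rag

app = FastAPI()

@app.post("/ask")
async def ask(
    session_id: str = Form(...),
    question: str = Form(...),
    file: Optional[UploadFile] = File(None),
):
    # Forward the optional upload as raw bytes so ask_rag can index it before answering.
    file_content = await file.read() if file is not None else None
    file_name = file.filename if file is not None else None
    return ask_rag(session_id, question, file_content=file_content, file_name=file_name)
```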