@@ -13,6 +13,9 @@ from tempfile import NamedTemporaryFile
 # Load environment variables
 load_dotenv()
 
+# Global dictionary to store memory instances per session
+session_memories = {}
+
 # Initialize Ollama LLM and Embeddings
 llm = Ollama(model="tinyllama", temperature=0.7)
 embeddings = OllamaEmbeddings(model="tinyllama")
@@ -29,66 +32,57 @@ def index_file(file_content: bytes, file_name: str):
     loader = TextLoader(temp_file_path)
     documents = loader.load()
 
-    # Split documents into chunks
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
     chunks = text_splitter.split_documents(documents)
 
-    # Add to vector store
     vector_store.add_documents(chunks)
 
-    # Clean up temp file
     os.unlink(temp_file_path)
 
 # Define prompt templates
-def get_prompt_with_history(memory):
+def get_prompt_with_history(session_id):
+    memory = session_memories.get(session_id)
+    if not memory:
+        memory = ConversationBufferWindowMemory(k=3)  # Keep last 3 turns
+        session_memories[session_id] = memory
     return PromptTemplate(
-        input_variables=["history", "question"],
+        input_variables=["question", "history"],
         template=f"Previous conversation:\n{{history}}\n\nResponda à seguinte pergunta: {{question}}"
     )
 
-def get_prompt_with_history_and_docs(memory, docs):
+def get_prompt_with_history_and_docs(session_id, docs):
+    memory = session_memories.get(session_id)
+    if not memory:
+        memory = ConversationBufferWindowMemory(k=3)
+        session_memories[session_id] = memory
+    history_text = memory.buffer if memory.buffer else "No previous conversation."
    docs_text = "\n".join([f"Source: {doc.page_content}" for doc in docs]) if docs else "No relevant documents found."
     return PromptTemplate(
-        input_variables=["history", "question"],
-        template=f"Previous conversation:\n{{history}}\n\nRelevant documents:\n{docs_text}\n\nResponda à seguinte pergunta usando as fontes relevantes e citando trechos como fontes: {{question}}"
+        input_variables=["question"],
+        template=f"Previous conversation:\n{history_text}\n\nRelevant documents:\n{docs_text}\n\nResponda à seguinte pergunta usando as fontes relevantes e citando trechos como fontes: {{question}}"
     )
 
 def get_answer(session_id: str, question: str) -> str:
-    # Get or initialize memory for this session
-    memory = ConversationBufferWindowMemory(memory_key="history", input_key="question", k=3, session_id=session_id)
-
-    # Create chain with dynamic prompt including history
-    prompt = get_prompt_with_history(memory)
+    memory = session_memories.setdefault(session_id, ConversationBufferWindowMemory(k=3))
+    prompt = get_prompt_with_history(session_id)
     chain = LLMChain(llm=llm, prompt=prompt, memory=memory)
-
-    # Get response
     response = chain.run(question=question)
-    response = response[:100] if len(response) > 100 else response  # Truncate if needed
-
+    response = response[:100] if len(response) > 100 else response
     return response
 
-# RAG function for /ask endpoint
 def ask_rag(session_id: str, question: str, file_content: bytes = None, file_name: str = None) -> dict:
-    # Get or initialize memory for this session
-    memory = ConversationBufferWindowMemory(memory_key="history", input_key="question", k=3, session_id=session_id)
-
     if file_content and file_name:
         index_file(file_content, file_name)
 
-    # Retrieve relevant documents
+    memory = session_memories.setdefault(session_id, ConversationBufferWindowMemory(k=3))
     docs = vector_store.similarity_search(question, k=3)
-
-    # Create chain with dynamic prompt including history and docs
-    prompt = get_prompt_with_history_and_docs(memory, docs)
+    prompt = get_prompt_with_history_and_docs(session_id, docs)
     chain = LLMChain(llm=llm, prompt=prompt, memory=memory)
 
-    # Get response
     response = chain.run(question=question)
     response = response[:100] if len(response) > 100 else response
 
-    # Prepare sources
     sources = [doc.page_content for doc in docs]
-
     return {"answer": response, "sources": sources}
 
 if __name__ == "__main__":
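
A minimal usage sketch of the session-scoped memory introduced in this patch, assuming the module above is importable as rag_service (hypothetical name), that vector_store is initialized near the top of the file, and that an Ollama server with the tinyllama model is running; the file name and questions are illustrative only:

# Usage sketch only: exercises the per-session memory added in this patch.
# "rag_service", "notes.txt", and the questions below are assumptions, not part of the patch.
from rag_service import ask_rag, get_answer, session_memories

print(get_answer("session-a", "Qual é a capital do Brasil?"))
print(get_answer("session-a", "Quantos habitantes ela tem?"))  # same session: sees the previous turn
print(get_answer("session-b", "Quantos habitantes ela tem?"))  # different session: empty history

with open("notes.txt", "rb") as f:
    result = ask_rag("session-a", "Resuma o documento enviado.", f.read(), "notes.txt")
print(result["answer"])
print(result["sources"])

# Each session_id now holds its own ConversationBufferWindowMemory instance.
assert set(session_memories) == {"session-a", "session-b"}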