
make each history unique per session

galo, 3 months ago
parent commit 074086c231
2 files changed, 25 additions and 30 deletions

+ 3 - 2
app/api/chat.py

@@ -38,13 +38,14 @@ async def ask(session_id: str = Form("default_session"), question: str = Form(..
 
 @router.get("/health")
 async def health():
-    current_time = datetime.utcnow().isoformat() + "Z"  # UTC time with 'Z' suffix
+    current_time = datetime.utcnow().isoformat() + "Z"
     log_entry = {
         "timestamp": current_time,
         "level": "INFO",
         "message": "Health check successful",
         "status": "healthy",
         "service": "chat-api",
-        "version": "1.0.0",  # Example version, adjust as needed
+        "version": "1.0.0",
+        "environment": os.getenv("ENVIRONMENT", "development")
     }
     return log_entry
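The new "environment" field is read from the ENVIRONMENT variable at request time, which also assumes `os` is already imported in chat.py. A minimal sketch of checking the field, assuming the router is mounted without a prefix on a FastAPI app importable as `app.main:app` (both are guesses about the project layout):

```python
# Rough sketch: exercise the updated /health endpoint with FastAPI's TestClient.
# The import path and the unprefixed "/health" route are assumptions, not part of the commit.
import os

from fastapi.testclient import TestClient

from app.main import app  # hypothetical module path; adjust to the real layout

os.environ["ENVIRONMENT"] = "staging"  # read by os.getenv() at request time

client = TestClient(app)
payload = client.get("/health").json()

assert payload["status"] == "healthy"
assert payload["version"] == "1.0.0"
assert payload["environment"] == "staging"  # would be "development" if ENVIRONMENT were unset
```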

+ 22 - 28
app/services/qa.py

@@ -13,6 +13,9 @@ from tempfile import NamedTemporaryFile
 # Load environment variables
 load_dotenv()
 
+# Global dictionary to store memory instances per session
+session_memories = {}
+
 # Initialize Ollama LLM and Embeddings
 llm = Ollama(model="tinyllama", temperature=0.7)
 embeddings = OllamaEmbeddings(model="tinyllama")
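Both prompt helpers in the hunk below repeat the same get-or-create lookup against this dictionary. A small sketch of how that lookup could be factored into one place (the `_get_session_memory` helper is hypothetical and not part of the commit):

```python
from langchain.memory import ConversationBufferWindowMemory

# Hypothetical helper (not in the commit): a single home for the repeated
# get-or-create lookup against the module-level session_memories dict above.
def _get_session_memory(session_id: str) -> ConversationBufferWindowMemory:
    memory = session_memories.get(session_id)
    if memory is None:
        memory = ConversationBufferWindowMemory(k=3)  # keep the last 3 turns
        session_memories[session_id] = memory
    return memory
```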
@@ -29,66 +32,57 @@ def index_file(file_content: bytes, file_name: str):
     loader = TextLoader(temp_file_path)
     documents = loader.load()
 
-    # Split documents into chunks
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
     chunks = text_splitter.split_documents(documents)
 
-    # Add to vector store
     vector_store.add_documents(chunks)
 
-    # Clean up temp file
     os.unlink(temp_file_path)
 
 # Define prompt templates
-def get_prompt_with_history(memory):
+def get_prompt_with_history(session_id):
+    memory = session_memories.get(session_id)
+    if not memory:
+        memory = ConversationBufferWindowMemory(k=3)  # Keep last 3 turns
+        session_memories[session_id] = memory
     return PromptTemplate(
-        input_variables=["history", "question"],
+        input_variables=["question", "history"],
         template=f"Previous conversation:\n{{history}}\n\nResponda à seguinte pergunta: {{question}}"
     )
 
-def get_prompt_with_history_and_docs(memory, docs):
+def get_prompt_with_history_and_docs(session_id, docs):
+    memory = session_memories.get(session_id)
+    if not memory:
+        memory = ConversationBufferWindowMemory(k=3)
+        session_memories[session_id] = memory
+    history_text = memory.buffer if memory.buffer else "No previous conversation."
     docs_text = "\n".join([f"Source: {doc.page_content}" for doc in docs]) if docs else "No relevant documents found."
     return PromptTemplate(
-        input_variables=["history", "question"],
-        template=f"Previous conversation:\n{{history}}\n\nRelevant documents:\n{docs_text}\n\nResponda à seguinte pergunta usando as fontes relevantes e citando trechos como fontes: {{question}}"
+        input_variables=["question", "history"],
+        template=f"Previous conversation:\n{history_text}\n\nRelevant documents:\n{docs_text}\n\nResponda à seguinte pergunta usando as fontes relevantes e citando trechos como fontes: {{question}}"
     )
 
 def get_answer(session_id: str, question: str) -> str:
-    # Get or initialize memory for this session
-    memory = ConversationBufferWindowMemory(memory_key="history", input_key="question", k=3, session_id=session_id)
-    
-    # Create chain with dynamic prompt including history
-    prompt = get_prompt_with_history(memory)
+    memory = session_memories.get(session_id, ConversationBufferWindowMemory(k=3))
+    prompt = get_prompt_with_history(session_id)
     chain = LLMChain(llm=llm, prompt=prompt, memory=memory)
-    
-    # Get response
     response = chain.run(question=question)
-    response = response[:100] if len(response) > 100 else response  # Truncate if needed
-    
+    response = response[:100] if len(response) > 100 else response
     return response
 
-# RAG function for /ask endpoint
 def ask_rag(session_id: str, question: str, file_content: bytes = None, file_name: str = None) -> dict:
-    # Get or initialize memory for this session
-    memory = ConversationBufferWindowMemory(memory_key="history", input_key="question", k=3, session_id=session_id)
-    
     if file_content and file_name:
         index_file(file_content, file_name)
     
-    # Retrieve relevant documents
+    memory = session_memories.get(session_id, ConversationBufferWindowMemory(k=3))
     docs = vector_store.similarity_search(question, k=3)
-    
-    # Create chain with dynamic prompt including history and docs
-    prompt = get_prompt_with_history_and_docs(memory, docs)
+    prompt = get_prompt_with_history_and_docs(session_id, docs)
     chain = LLMChain(llm=llm, prompt=prompt, memory=memory)
     
-    # Get response
     response = chain.run(question=question)
     response = response[:100] if len(response) > 100 else response
     
-    # Prepare sources
     sources = [doc.page_content for doc in docs]
-    
     return {"answer": response, "sources": sources}
 
 if __name__ == "__main__":
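
What the commit message promises, one history per session, comes down to each session_id owning its own ConversationBufferWindowMemory in the module-level dict. A minimal standalone sketch of that isolation (only langchain is needed, no running Ollama; the `remember` helper exists only for this demo):

```python
from langchain.memory import ConversationBufferWindowMemory

session_memories = {}

def remember(session_id: str, question: str, answer: str) -> None:
    # Same get-or-create pattern as qa.py, written with dict.setdefault for brevity.
    memory = session_memories.setdefault(session_id, ConversationBufferWindowMemory(k=3))
    memory.save_context({"input": question}, {"output": answer})

remember("alice", "What is RAG?", "Retrieval-augmented generation.")
remember("bob", "Which model answers?", "tinyllama served by Ollama.")

print(session_memories["alice"].buffer)  # contains only alice's turn
print(session_memories["bob"].buffer)    # contains only bob's turn
```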