qa.py 4.7 KB

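"""Simple RAG question-answering helpers built on LangChain, Ollama and Chroma.

Uploaded files are indexed into an in-memory Chroma vector store, and answers
are generated with a local TinyLlama model while keeping a short (k=3)
conversation memory per session.
"""
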
from langchain_community.llms import Ollama
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain.memory import ConversationBufferWindowMemory
import os
from dotenv import load_dotenv
from tempfile import NamedTemporaryFile
import logging
# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Global dictionary to store memory instances per session
session_memories = {}

# Initialize Ollama LLM and Embeddings
llm = Ollama(model="tinyllama", temperature=0.7)
embeddings = OllamaEmbeddings(model="tinyllama")

# Initialize global Chroma vector store (in-memory)
vector_store = Chroma.from_texts([""], embeddings)  # Initialize empty store

# Function to index uploaded file
def index_file(file_content: bytes, file_name: str):
    with NamedTemporaryFile(delete=False, suffix=os.path.splitext(file_name)[1]) as temp_file:
        temp_file.write(file_content)
        temp_file_path = temp_file.name
    loader = TextLoader(temp_file_path)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = text_splitter.split_documents(documents)
    vector_store.add_documents(chunks)
    os.unlink(temp_file_path)

# Define prompt templates
def get_prompt_with_history(session_id):
    memory = session_memories.get(session_id)
    if not memory:
        memory = ConversationBufferWindowMemory(k=3)
        session_memories[session_id] = memory
        logger.debug(f"Initialized new memory for session_id: {session_id}")
    logger.debug(f"Memory buffer for session_id {session_id}: {memory.buffer}")
    return PromptTemplate(
        input_variables=["question", "history"],
        template="Previous conversation:\n{history}\n\nResponda à seguinte pergunta: {question}",
    )

def get_prompt_with_history_and_docs(session_id, docs):
    memory = session_memories.get(session_id)
    if not memory:
        memory = ConversationBufferWindowMemory(k=3)
        session_memories[session_id] = memory
        logger.debug(f"Initialized new memory for session_id: {session_id}")
    logger.debug(f"Memory buffer for session_id {session_id}: {memory.buffer}")
    docs_text = "\n".join([f"Source: {doc.page_content}" for doc in docs]) if docs else "No relevant documents found."
    # Escape braces so document content cannot break prompt formatting
    docs_text = docs_text.replace("{", "{{").replace("}", "}}")
    # The conversation history is supplied by the chain's memory through the
    # {history} placeholder; the Portuguese instruction reads "Answer the
    # following question using the relevant sources, citing excerpts as sources".
    return PromptTemplate(
        input_variables=["question", "history"],
        template=f"Previous conversation:\n{{history}}\n\nRelevant documents:\n{docs_text}\n\nResponda à seguinte pergunta usando as fontes relevantes e citando trechos como fontes: {{question}}",
    )

def get_answer(session_id: str, question: str) -> str:
    memory = session_memories.get(session_id, ConversationBufferWindowMemory(k=3))
    if session_id not in session_memories:
        session_memories[session_id] = memory
        logger.debug(f"New memory assigned for session_id: {session_id}")
    prompt = get_prompt_with_history(session_id)
    chain = LLMChain(llm=llm, prompt=prompt, memory=memory)
    logger.debug(f"Before run - Memory buffer: {memory.buffer}")
    response = chain.run(question=question)
    logger.debug(f"After run - Memory buffer: {memory.buffer}")
    # Cap the answer at 100 characters
    response = response[:100]
    return response

def ask_rag(session_id: str, question: str, file_content: bytes = None, file_name: str = None) -> dict:
    if file_content and file_name:
        index_file(file_content, file_name)
    memory = session_memories.get(session_id, ConversationBufferWindowMemory(k=3))
    if session_id not in session_memories:
        session_memories[session_id] = memory
        logger.debug(f"New memory assigned for session_id: {session_id}")
    docs = vector_store.similarity_search(question, k=3)
    prompt = get_prompt_with_history_and_docs(session_id, docs)
    chain = LLMChain(llm=llm, prompt=prompt, memory=memory)
    logger.debug(f"Before run - Memory buffer: {memory.buffer}")
    response = chain.run(question=question)
    logger.debug(f"After run - Memory buffer: {memory.buffer}")
    # Cap the answer at 100 characters
    response = response[:100]
    sources = [doc.page_content for doc in docs]
    return {"answer": response, "sources": sources}

if __name__ == "__main__":
    session_id = "test_session"
    print(get_answer(session_id, "Qual a capital da França?"))
    print(get_answer(session_id, "E a capital da Espanha?"))
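    # Minimal RAG usage sketch (not part of the original demo): index a local
    # plain-text file and ask a question about it via ask_rag. The file name
    # "example.txt" and the question ("What is the document about?") are
    # illustrative assumptions; adjust them to your own data.
    try:
        with open("example.txt", "rb") as f:
            rag_result = ask_rag(session_id, "Sobre o que fala o documento?", f.read(), "example.txt")
        print(rag_result["answer"])
        print(rag_result["sources"])
    except FileNotFoundError:
        logger.debug("example.txt not found; skipping the RAG example")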