qa.py

import os
from tempfile import NamedTemporaryFile

from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter
# Load environment variables
load_dotenv()

# Global dictionary to store memory instances per session
session_memories = {}

# Initialize Ollama LLM and embeddings
llm = Ollama(model="tinyllama", temperature=0.7)
embeddings = OllamaEmbeddings(model="tinyllama")

# Initialize an empty in-memory Chroma vector store. Constructing it directly
# avoids seeding the index with an empty-string document, which
# Chroma.from_texts([""], ...) would otherwise return as a spurious
# similarity_search hit.
vector_store = Chroma(embedding_function=embeddings)
# Index an uploaded file into the shared vector store
def index_file(file_content: bytes, file_name: str):
    # Write the raw bytes to a temp file so TextLoader can read them from disk
    with NamedTemporaryFile(delete=False, suffix=os.path.splitext(file_name)[1]) as temp_file:
        temp_file.write(file_content)
        temp_file_path = temp_file.name
    loader = TextLoader(temp_file_path)
    documents = loader.load()
    # Split into ~1000-character chunks with 200 characters of overlap
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = text_splitter.split_documents(documents)
    vector_store.add_documents(chunks)
    os.unlink(temp_file_path)
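# Example call (a sketch; in practice file_content/file_name would come from an
# upload handler such as a web framework's file object — hypothetical usage):
#   index_file(open("notes.txt", "rb").read(), "notes.txt")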
# Define prompt templates
def get_prompt_with_history(session_id):
    memory = session_memories.get(session_id)
    if not memory:
        memory = ConversationBufferWindowMemory(k=3)  # Keep last 3 turns
        session_memories[session_id] = memory
    return PromptTemplate(
        input_variables=["question", "history"],
        template="Previous conversation:\n{history}\n\nAnswer the following question: {question}"
    )
def get_prompt_with_history_and_docs(session_id, docs):
    memory = session_memories.get(session_id)
    if not memory:
        memory = ConversationBufferWindowMemory(k=3)
        session_memories[session_id] = memory
    docs_text = "\n".join(f"Source: {doc.page_content}" for doc in docs) if docs else "No relevant documents found."
    # {history} stays a placeholder so the chain's memory can fill it at run
    # time; baking memory.buffer in as literal text would leave a declared
    # input variable with no placeholder and serve stale history on reuse.
    return PromptTemplate(
        input_variables=["question", "history"],
        template=f"Previous conversation:\n{{history}}\n\nRelevant documents:\n{docs_text}\n\nAnswer the following question using the relevant sources, citing excerpts as sources: {{question}}"
    )
def get_answer(session_id: str, question: str) -> str:
    # Build the prompt first: it also creates and stores the session memory,
    # so the chain below reuses the persisted instance instead of a throwaway
    # default that would lose history between calls.
    prompt = get_prompt_with_history(session_id)
    memory = session_memories[session_id]
    chain = LLMChain(llm=llm, prompt=prompt, memory=memory)
    response = chain.run(question=question)
    return response[:100]  # truncate long answers; slicing is a no-op for short ones
def ask_rag(session_id: str, question: str, file_content: bytes = None, file_name: str = None) -> dict:
    if file_content and file_name:
        index_file(file_content, file_name)
    # Retrieve the 3 most similar chunks for the question
    docs = vector_store.similarity_search(question, k=3)
    # Building the prompt also creates and stores the session memory
    prompt = get_prompt_with_history_and_docs(session_id, docs)
    memory = session_memories[session_id]
    chain = LLMChain(llm=llm, prompt=prompt, memory=memory)
    response = chain.run(question=question)
    sources = [doc.page_content for doc in docs]
    return {"answer": response[:100], "sources": sources}
if __name__ == "__main__":
    session_id = "test_session"
    print(get_answer(session_id, "What is the capital of France?"))
    print(get_answer(session_id, "And the capital of Spain?"))
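    # A minimal sketch of the RAG path, assuming a local text file named
    # "sample.txt" exists ("sample.txt" is a hypothetical name; substitute
    # any real document before running).
    if os.path.exists("sample.txt"):
        with open("sample.txt", "rb") as f:
            result = ask_rag(session_id, "What is this document about?", f.read(), "sample.txt")
        print(result["answer"])
        print("Sources:", result["sources"])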