qa.py

from langchain_community.llms import Ollama
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
import os
from dotenv import load_dotenv
from tempfile import NamedTemporaryFile
# Load environment variables
load_dotenv()

# Global dictionary mapping session IDs to their chat history
chat_history = {}

# Initialize the Ollama LLM and embeddings
llm = Ollama(model="tinyllama", temperature=0.7)
embeddings = OllamaEmbeddings(model="tinyllama")

# Initialize a global in-memory Chroma vector store.
# Chroma.from_texts needs at least one text, so seed it with an empty string.
vector_store = Chroma.from_texts([""], embeddings)

# Index an uploaded file into the global vector store
def index_file(file_content: bytes, file_name: str):
    # Write the upload to a temp file so TextLoader can read it from disk
    with NamedTemporaryFile(delete=False, suffix=os.path.splitext(file_name)[1]) as temp_file:
        temp_file.write(file_content)
        temp_file_path = temp_file.name
    loader = TextLoader(temp_file_path)
    documents = loader.load()
    # Split documents into overlapping chunks
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = text_splitter.split_documents(documents)
    # Add the chunks to the vector store
    vector_store.add_documents(chunks)
    # Clean up the temp file
    os.unlink(temp_file_path)

# Prompt builders that inline the session's chat history
def get_prompt_with_history(session_id):
    history = chat_history.get(session_id, [])
    history_text = "\n".join(f"User: {msg['question']}\nAI: {msg['answer']}" for msg in history) if history else "No previous conversation."
    return PromptTemplate(
        input_variables=["question"],
        template=f"Previous conversation:\n{history_text}\n\nAnswer the following question: {{question}}",
    )


def get_prompt_with_history_and_docs(session_id, docs):
    history = chat_history.get(session_id, [])
    history_text = "\n".join(f"User: {msg['question']}\nAI: {msg['answer']}" for msg in history) if history else "No previous conversation."
    docs_text = "\n".join(f"Source: {doc.page_content}" for doc in docs) if docs else "No relevant documents found."
    return PromptTemplate(
        input_variables=["question"],
        template=f"Previous conversation:\n{history_text}\n\nRelevant documents:\n{docs_text}\n\nAnswer the following question using the relevant sources, citing excerpts as sources: {{question}}",
    )

def get_answer(session_id: str, question: str) -> str:
    # Get or initialize the chat history for this session
    if session_id not in chat_history:
        chat_history[session_id] = []
    # Build a chain whose prompt includes the session history
    prompt = get_prompt_with_history(session_id)
    chain = LLMChain(llm=llm, prompt=prompt)
    # Get the response
    response = chain.run(question=question)
    response = response[:100]  # Truncate if needed; slicing is a no-op on shorter strings
    # Store the interaction in the session history
    chat_history[session_id].append({"question": question, "answer": response})
    return response

# RAG function for the /ask endpoint
def ask_rag(session_id: str, question: str, file_content: bytes = None, file_name: str = None) -> dict:
    # Index the uploaded file, if one was provided
    if file_content and file_name:
        index_file(file_content, file_name)
    if session_id not in chat_history:
        chat_history[session_id] = []
    # Retrieve the three most similar chunks for the question
    docs = vector_store.similarity_search(question, k=3)
    prompt = get_prompt_with_history_and_docs(session_id, docs)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=question)
    response = response[:100]  # Same 100-character cap as get_answer
    chat_history[session_id].append({"question": question, "answer": response})
    sources = [doc.page_content for doc in docs]
    return {"answer": response, "sources": sources}

if __name__ == "__main__":
    session_id = "test_session"
    print(get_answer(session_id, "What is the capital of France?"))
    print(get_answer(session_id, "And the capital of Spain?"))
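    # Illustrative RAG demo (an assumption, not part of the original script):
    # exercise ask_rag by indexing a small hypothetical document and querying it.
    # The file name and contents below are made up for demonstration.
    sample = "Paris is the capital of France. Madrid is the capital of Spain."
    result = ask_rag(
        session_id,
        "Which city is the capital of Spain?",
        file_content=sample.encode("utf-8"),
        file_name="capitals.txt",
    )
    print(result["answer"])
    print(result["sources"])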