qa.py

from langchain_community.llms import Ollama
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Global dictionary to store chat history (in-memory for simplicity)
chat_history = {}

# Initialize the Ollama LLM with the tinyllama model
llm = Ollama(
    model="tinyllama",
    temperature=0.7
)

# Build a prompt template that includes this session's conversation history
def get_prompt_with_history(session_id):
    history = chat_history.get(session_id, [])
    history_text = "\n".join(
        [f"User: {msg['question']}\nAI: {msg['answer']}" for msg in history]
    ) if history else "No previous conversation."
    return PromptTemplate(
        input_variables=["question"],
        template=f"Previous conversation:\n{history_text}\n\nAnswer the following question: {{question}}"
    )

def get_answer(session_id: str, question: str) -> str:
    # Get or initialize chat history for this session
    if session_id not in chat_history:
        chat_history[session_id] = []

    # Create a chain with a dynamic prompt that includes the history
    prompt = get_prompt_with_history(session_id)
    chain = LLMChain(llm=llm, prompt=prompt)

    # Get the response and truncate it to 100 characters if needed
    response = chain.run(question=question)
    response = response[:100] if len(response) > 100 else response

    # Store the interaction in history
    chat_history[session_id].append({"question": question, "answer": response})
    return response

if __name__ == "__main__":
    session_id = "test_session"
    print(get_answer(session_id, "What is the capital of France?"))
    print(get_answer(session_id, "And the capital of Spain?"))
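
    # Optional sketch, not part of the original file: because chat_history is
    # keyed by session_id, a different session starts with an empty history and
    # its prompt contains "No previous conversation." The session name and
    # question below are hypothetical, added only to illustrate that isolation.
    other_session = "another_session"
    print(get_answer(other_session, "What is the capital of Italy?"))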