|
|
@@ -7,23 +7,43 @@ from dotenv import load_dotenv
|
|
|
# Load environment variables
|
|
|
load_dotenv()
|
|
|
|
|
|
-# Initialize Ollama LLM
|
|
|
+# Global dictionary to store chat history (in-memory for simplicity)
|
|
|
+chat_history = {}
|
|
|
+
|
|
|
+# Initialize Ollama LLM with tinyllama
|
|
|
llm = Ollama(
|
|
|
- model="mistral",
|
|
|
+ model="tinyllama",
|
|
|
temperature=0.7
|
|
|
)
|
|
|
|
|
|
-# Define a simple prompt template
|
|
|
-prompt = PromptTemplate(
|
|
|
- input_variables=["question"],
|
|
|
- template="Responda à seguinte pergunta: {question}"
|
|
|
-)
|
|
|
-
|
|
|
-# Create the LLM chain
|
|
|
-chain = LLMChain(llm=llm, prompt=prompt)
|
|
|
+# Define a prompt template that includes history
|
|
|
+def get_prompt_with_history(session_id):
|
|
|
+ history = chat_history.get(session_id, [])
|
|
|
+    history_text = ("\n".join([f"User: {msg['question']}\nAI: {msg['answer']}" for msg in history]) if history else "No previous conversation.").replace("{", "{{").replace("}", "}}")  # escape braces so PromptTemplate does not treat history text as template variables
|
|
|
+ return PromptTemplate(
|
|
|
+ input_variables=["question"],
|
|
|
+ template=f"Previous conversation:\n{history_text}\n\nResponda à seguinte pergunta: {{question}}"
|
|
|
+ )
|
|
|
|
|
|
-def get_answer(question: str) -> str:
|
|
|
- return chain.run(question=question)
|
|
|
+def get_answer(session_id: str, question: str) -> str:
|
|
|
+ # Get or initialize chat history for this session
|
|
|
+ if session_id not in chat_history:
|
|
|
+ chat_history[session_id] = []
|
|
|
+
|
|
|
+ # Create chain with dynamic prompt including history
|
|
|
+ prompt = get_prompt_with_history(session_id)
|
|
|
+ chain = LLMChain(llm=llm, prompt=prompt)
|
|
|
+
|
|
|
+ # Get response
|
|
|
+ response = chain.run(question=question)
|
|
|
+    response = response[:100]  # Truncate if needed (slicing is already a no-op for short strings)
|
|
|
+
|
|
|
+ # Store the interaction in history
|
|
|
+ chat_history[session_id].append({"question": question, "answer": response})
|
|
|
+
|
|
|
+ return response
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- print(get_answer("Qual a capital da França?"))
|
|
|
+ session_id = "test_session"
|
|
|
+ print(get_answer(session_id, "Qual a capital da França?"))
|
|
|
+ print(get_answer(session_id, "E a capital da Espanha?"))
|