from langchain_community.llms import Ollama
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
import os
from dotenv import load_dotenv
from tempfile import NamedTemporaryFile

# Load environment variables
load_dotenv()

# Global dictionary for chat history, keyed by session_id
chat_history = {}

# Initialize Ollama LLM and embeddings
llm = Ollama(model="tinyllama", temperature=0.7)
embeddings = OllamaEmbeddings(model="tinyllama")

# Initialize global Chroma vector store (in-memory), seeded with a single empty text
vector_store = Chroma.from_texts([""], embeddings)

# Index an uploaded file into the vector store
def index_file(file_content: bytes, file_name: str):
    # Write the upload to a temporary file so TextLoader can read it from disk
    with NamedTemporaryFile(delete=False, suffix=os.path.splitext(file_name)[1]) as temp_file:
        temp_file.write(file_content)
        temp_file_path = temp_file.name

    loader = TextLoader(temp_file_path)
    documents = loader.load()

    # Split documents into chunks
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = text_splitter.split_documents(documents)

    # Add chunks to the vector store
    vector_store.add_documents(chunks)

    # Clean up the temporary file
    os.unlink(temp_file_path)
# Define prompt templates
def get_prompt_with_history(session_id):
    # Build a prompt that embeds the session's conversation history
    history = chat_history.get(session_id, [])
    history_text = "\n".join([f"User: {msg['question']}\nAI: {msg['answer']}" for msg in history]) if history else "No previous conversation."
    return PromptTemplate(
        input_variables=["question"],
        template=f"Previous conversation:\n{history_text}\n\nAnswer the following question: {{question}}"
    )

def get_prompt_with_history_and_docs(session_id, docs):
    # Build a prompt that embeds both the conversation history and the retrieved documents
    history = chat_history.get(session_id, [])
    history_text = "\n".join([f"User: {msg['question']}\nAI: {msg['answer']}" for msg in history]) if history else "No previous conversation."
    docs_text = "\n".join([f"Source: {doc.page_content}" for doc in docs]) if docs else "No relevant documents found."
    return PromptTemplate(
        input_variables=["question"],
        template=f"Previous conversation:\n{history_text}\n\nRelevant documents:\n{docs_text}\n\nAnswer the following question using the relevant sources and citing excerpts as sources: {{question}}"
    )
# Answer a question using only the conversation history (no retrieval)
def get_answer(session_id: str, question: str) -> str:
    # Get or initialize chat history for this session
    if session_id not in chat_history:
        chat_history[session_id] = []

    # Create a chain with a dynamic prompt that includes the history
    prompt = get_prompt_with_history(session_id)
    chain = LLMChain(llm=llm, prompt=prompt)

    # Get the response
    response = chain.run(question=question)
    response = response[:100] if len(response) > 100 else response  # Truncate if needed

    # Store the interaction in the history
    chat_history[session_id].append({"question": question, "answer": response})

    return response
# RAG function for the /ask endpoint
def ask_rag(session_id: str, question: str, file_content: bytes = None, file_name: str = None) -> dict:
    # If a file was uploaded, index it before answering
    if file_content and file_name:
        index_file(file_content, file_name)

    if session_id not in chat_history:
        chat_history[session_id] = []

    # Retrieve the three most relevant chunks for the question
    docs = vector_store.similarity_search(question, k=3)

    prompt = get_prompt_with_history_and_docs(session_id, docs)
    chain = LLMChain(llm=llm, prompt=prompt)

    response = chain.run(question=question)
    response = response[:100] if len(response) > 100 else response  # Truncate if needed

    chat_history[session_id].append({"question": question, "answer": response})

    sources = [doc.page_content for doc in docs]

    return {"answer": response, "sources": sources}
if __name__ == "__main__":
    session_id = "test_session"
    print(get_answer(session_id, "What is the capital of France?"))
    print(get_answer(session_id, "And the capital of Spain?"))