
added session id

galo · 3 months ago · commit 6294c3425e
6 changed files with 57 additions and 33 deletions
  1. Dockerfile (+1 -1)
  2. README.md (+2 -6)
  3. app/api/chat.py (+15 -6)
  4. app/main.py (+5 -6)
  5. app/schemas/chat.py (+1 -1)
  6. app/services/qa.py (+33 -13)

+ 1 - 1
Dockerfile

@@ -21,4 +21,4 @@ COPY . .
 EXPOSE 8000
 
 # Command to start Ollama server, pull the model if needed, and run FastAPI
-CMD ["sh", "-c", "ollama serve & (ollama list | grep -q mistral || ollama pull mistral) && uvicorn app.main:app --host 0.0.0.0 --port 8000 & wait"]
+CMD ["sh", "-c", "ollama serve & (ollama list | grep -q tinyllama || ollama pull tinyllama) && uvicorn app.main:app --host 0.0.0.0 --port 8000 & wait"]

+ 2 - 6
README.md

@@ -18,13 +18,9 @@ cd LangChain
 ```
 3. 
 ```
-docker build -t chat-api .
+docker build -t chat-api . && docker run -p 8000:8000 --env-file .env chat-api
 ```
 4. 
 ```
-docker run -p 8000:8000 --env-file .env chat-api
-```
-5. 
-```
-curl -X POST http://localhost:8000/chat/ -H "Content-Type: application/json" -d "{\"message\": \"Qual a capital da França?\"}"
+curl -X POST http://localhost:8000/chat/?session_id=test_session -H "Content-Type: application/json" -d "{\"message\": \"Qual a capital da França?\"}"
 ```
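
The same request can be issued from Python; a small sketch assuming the container from the previous step is listening on localhost:8000 (the requests library is not part of this project and is used here only for illustration):

```
# Sketch of the curl call above using the requests library.
# Assumes the chat-api container is running on localhost:8000.
import requests

resp = requests.post(
    "http://localhost:8000/chat/",
    params={"session_id": "test_session"},          # session id travels in the query string
    json={"message": "Qual a capital da França?"},  # "What is the capital of France?"
    timeout=60,
)
resp.raise_for_status()
data = resp.json()
print(data["answer"], f"({data['latency_ms']} ms)")
```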

+ 15 - 6
app/api/chat.py

@@ -1,16 +1,25 @@
 from fastapi import APIRouter, HTTPException
+from app.services.qa import get_answer, chat_history
 from app.schemas.chat import ChatRequest, ChatResponse
-from app.services.qa import get_answer
 import time
 
 router = APIRouter()
 
[email protected]("/", response_model=ChatResponse)
-async def chat(request: ChatRequest):
[email protected]("/chat/")
+async def chat(request: ChatRequest, session_id: str = "default_session"):
+    start_time = time.time()
+    
     try:
-        start_time = time.time()
-        answer = get_answer(request.message)
+        answer = get_answer(session_id, request.message)
         latency_ms = int((time.time() - start_time) * 1000)
         return ChatResponse(answer=answer, latency_ms=latency_ms)
     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(status_code=500, detail=f"Ollama call failed with error: {str(e)}")
+
[email protected]("/health")
+async def health():
+    return {"status": "healthy"}
+
[email protected]("/sessions")
+async def list_sessions():
+    return {"sessions": {sid: len(history) for sid, history in chat_history.items()}}

+ 5 - 6
app/main.py

@@ -1,12 +1,11 @@
 from fastapi import FastAPI
 from app.api.chat import router as chat_router
 
-app = FastAPI(title="Chat API with LangChain")
+app = FastAPI()
 
 # Include the chat router
-app.include_router(chat_router, prefix="/chat")
+app.include_router(chat_router)
 
-# Health check endpoint
[email protected]("/health")
-async def health_check():
-    return {"status": "healthy"}
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
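
With the prefix="/chat" argument dropped, the public paths now come entirely from the routes themselves (the router declares "/chat/", "/health" and "/sessions" directly), so the URL used in the README is unchanged. A stand-in sketch, kept free of the Ollama dependency, that prints the mounted paths:

```
# Stand-in sketch: including a router without a prefix leaves the route
# paths exactly as declared on the router (mirrors app/api/chat.py).
from fastapi import APIRouter, FastAPI

router = APIRouter()

@router.post("/chat/")
async def chat():
    return {}

@router.get("/health")
async def health():
    return {"status": "healthy"}

@router.get("/sessions")
async def list_sessions():
    return {"sessions": {}}

app = FastAPI()
app.include_router(router)  # no prefix="/chat" anymore

print([route.path for route in app.routes])
# expect '/chat/', '/health' and '/sessions' among the printed paths
```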

+ 1 - 1
app/schemas/chat.py

@@ -5,4 +5,4 @@ class ChatRequest(BaseModel):
 
 class ChatResponse(BaseModel):
     answer: str
-    latency_ms: int
+    latency_ms: int

+ 33 - 13
app/services/qa.py

@@ -7,23 +7,43 @@ from dotenv import load_dotenv
 # Load environment variables
 load_dotenv()
 
-# Initialize Ollama LLM
+# Global dictionary to store chat history (in-memory for simplicity)
+chat_history = {}
+
+# Initialize Ollama LLM with tinyllama
 llm = Ollama(
-    model="mistral",
+    model="tinyllama",
     temperature=0.7
 )
 
-# Define a simple prompt template
-prompt = PromptTemplate(
-    input_variables=["question"],
-    template="Responda à seguinte pergunta: {question}"
-)
-
-# Create the LLM chain
-chain = LLMChain(llm=llm, prompt=prompt)
+# Define a prompt template that includes history
+def get_prompt_with_history(session_id):
+    history = chat_history.get(session_id, [])
+    history_text = "\n".join([f"User: {msg['question']}\nAI: {msg['answer']}" for msg in history]) if history else "No previous conversation."
+    return PromptTemplate(
+        input_variables=["question"],
+        template=f"Previous conversation:\n{history_text}\n\nResponda à seguinte pergunta: {{question}}"
+    )
 
-def get_answer(question: str) -> str:
-    return chain.run(question=question)
+def get_answer(session_id: str, question: str) -> str:
+    # Get or initialize chat history for this session
+    if session_id not in chat_history:
+        chat_history[session_id] = []
+    
+    # Create chain with dynamic prompt including history
+    prompt = get_prompt_with_history(session_id)
+    chain = LLMChain(llm=llm, prompt=prompt)
+    
+    # Get response
+    response = chain.run(question=question)
+    response = response[:100] if len(response) > 100 else response  # Truncate if needed
+    
+    # Store the interaction in history
+    chat_history[session_id].append({"question": question, "answer": response})
+    
+    return response
 
 if __name__ == "__main__":
-    print(get_answer("Qual a capital da França?"))
+    session_id = "test_session"
+    print(get_answer(session_id, "Qual a capital da França?"))
+    print(get_answer(session_id, "E a capital da Espanha?"))
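
The rewritten get_answer rebuilds the prompt template on every call, folding previous turns into plain text (and truncating each answer to 100 characters before it is returned and stored). A small sketch of the prompt text the second question in the __main__ block would see, mirroring get_prompt_with_history without LangChain; the stored answer below is an illustrative placeholder:

```
# Sketch of the prompt produced for the second test question, mirroring
# get_prompt_with_history without LangChain. The stored answer is an
# illustrative placeholder, not real model output.
chat_history = {
    "test_session": [
        {"question": "Qual a capital da França?", "answer": "A capital da França é Paris."}
    ]
}

def render_prompt(session_id: str, question: str) -> str:
    history = chat_history.get(session_id, [])
    history_text = "\n".join(
        f"User: {msg['question']}\nAI: {msg['answer']}" for msg in history
    ) if history else "No previous conversation."
    return f"Previous conversation:\n{history_text}\n\nResponda à seguinte pergunta: {question}"

print(render_prompt("test_session", "E a capital da Espanha?"))
```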