浏览代码

updating endpoints

galo 3 月之前
父节点
当前提交
9e3aef9816
共有 3 个文件被更改,包括 14 次插入5 次删除
  1. 1 1
      README.md
  2. 4 4
      app/api/chat.py
  3. 9 0
      app/schemas/chat.py

+ 1 - 1
README.md

@@ -26,5 +26,5 @@ curl -X POST http://localhost:8000/chat/?session_id=test_session -H "Content-Typ
 
 ou
 
-curl -X POST "http://localhost:8000/ask/?session_id=test_session&question=Qual%20a%20capital%20da%20Fran%C3%A7a%3F" -F "file=@E:/test.txt"
+curl -X POST http://localhost:8000/ask/ -F "session_id=test_session" -F "question=What is the capital of France?" -F "file=@E:/test.txt"
 ```

+ 4 - 4
app/api/chat.py

@@ -1,6 +1,6 @@
-from fastapi import APIRouter, HTTPException, UploadFile, File, Query
+from fastapi import APIRouter, HTTPException, UploadFile, File, Form
 from app.services.qa import get_answer, ask_rag
-from app.schemas.chat import ChatRequest, ChatResponse
+from app.schemas.chat import ChatRequest, AskRequest, ChatResponse, AskResponse
 import time
 
 router = APIRouter()
@@ -17,7 +17,7 @@ async def chat(request: ChatRequest, session_id: str = "default_session"):
         raise HTTPException(status_code=500, detail=f"Ollama call failed with error: {str(e)}")
 
 @router.post("/ask/")
-async def ask(session_id: str = Query("default_session"), question: str = Query(...), file: UploadFile = File(None)):
+async def ask(session_id: str = Form("default_session"), question: str = Form(...), file: UploadFile = File(None)):
     start_time = time.time()
     
     try:
@@ -31,7 +31,7 @@ async def ask(session_id: str = Query("default_session"), question: str = Query(
         
         result = ask_rag(session_id, question, file_content, file_name)
         latency_ms = int((time.time() - start_time) * 1000)
-        return {"answer": result["answer"], "sources": result["sources"], "latency_ms": latency_ms}
+        return AskResponse(answer=result["answer"], sources=result["sources"], latency_ms=latency_ms)
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Ollama call failed with error: {str(e)}")
 

+ 9 - 0
app/schemas/chat.py

@@ -3,6 +3,15 @@ from pydantic import BaseModel
 class ChatRequest(BaseModel):
     message: str
 
+class AskRequest(BaseModel):
+    question: str
+
 class ChatResponse(BaseModel):
     answer: str
+    latency_ms: int
+
+# Optional: Define AskResponse for consistency (though not currently used)
+class AskResponse(BaseModel):
+    answer: str
+    sources: list[str]
     latency_ms: int