@@ -1,6 +1,6 @@
-from fastapi import APIRouter, HTTPException, UploadFile, File, Query
+from fastapi import APIRouter, HTTPException, UploadFile, File, Form
 from app.services.qa import get_answer, ask_rag
-from app.schemas.chat import ChatRequest, ChatResponse
+from app.schemas.chat import ChatRequest, AskRequest, ChatResponse, AskResponse
 import time
 
 router = APIRouter()
@@ -17,7 +17,7 @@ async def chat(request: ChatRequest, session_id: str = "default_session"):
         raise HTTPException(status_code=500, detail=f"Ollama call failed with error: {str(e)}")
 
 @router.post("/ask/")
-async def ask(session_id: str = Query("default_session"), question: str = Query(...), file: UploadFile = File(None)):
+async def ask(session_id: str = Form("default_session"), question: str = Form(...), file: UploadFile = File(None)):
     start_time = time.time()
 
     try:
@@ -31,7 +31,7 @@ async def ask(session_id: str = Query("default_session"), question: str = Query(
 
         result = ask_rag(session_id, question, file_content, file_name)
         latency_ms = int((time.time() - start_time) * 1000)
-        return {"answer": result["answer"], "sources": result["sources"], "latency_ms": latency_ms}
+        return AskResponse(answer=result["answer"], sources=result["sources"], latency_ms=latency_ms)
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Ollama call failed with error: {str(e)}")
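
Note: the AskRequest and AskResponse models now imported from app.schemas.chat are not part of this diff. A minimal sketch of what they presumably look like, inferred from the fields used in the /ask/ handler above; the names, field types, and defaults here are assumptions, not the actual schema file:

# Hypothetical sketch of app/schemas/chat.py additions -- that file is not shown
# in this diff, so everything below is inferred from the router code above.
from typing import List
from pydantic import BaseModel

class AskRequest(BaseModel):
    # Imported by the router but not used by the form-based /ask/ endpoint shown
    # above; presumably mirrors the form fields for other callers.
    session_id: str = "default_session"
    question: str

class AskResponse(BaseModel):
    answer: str
    sources: List[str]   # assumed element type; ask_rag may return richer objects
    latency_ms: int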
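
With Query(...) replaced by Form(...), clients send session_id and question as multipart/form-data fields in the request body rather than as query-string parameters, so they travel in the same body as the optional file upload. A usage sketch under assumptions (the base URL and router prefix are placeholders; adjust to however the router is mounted):

# Example client call after this change; requests encodes `data` plus `files`
# as multipart/form-data, matching the Form(...) and File(...) parameters.
import requests

resp = requests.post(
    "http://localhost:8000/ask/",
    data={"session_id": "default_session", "question": "What does the attached report conclude?"},
    files={"file": ("report.pdf", open("report.pdf", "rb"), "application/pdf")},
)
resp.raise_for_status()
print(resp.json())  # {"answer": ..., "sources": [...], "latency_ms": ...}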