author     Fuwn <[email protected]>    2025-07-29 11:39:41 +0200
committer  Fuwn <[email protected]>    2025-07-29 11:39:41 +0200
commit     3d72b0f85f7a14c675d543e614f8a9c617d550eb
tree       9d50aa12b6b0a847899bb0a8a96090d53ef68436
parent     feat(umapyai): Improve system prompt self definition
feat(umapyai): Response streaming
-rw-r--r--  pyproject.toml             1
-rw-r--r--  requirements-dev.lock      8
-rw-r--r--  requirements.lock          8
-rw-r--r--  src/umapyai/__init__.py   83
-rw-r--r--  src/umapyai/chat.html     40
5 files changed, 93 insertions, 47 deletions
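
This commit replaces the single-shot JSON POST handler for /api/ask with a flask-sock
WebSocket route that streams typed JSON frames ("sources", "answer_chunk", "history",
"error") as Ollama produces tokens. A minimal sketch of the flask-sock pattern the
commit adopts is shown below; the route name and frame fields mirror the diff, while
the echo-style body, host, and port are illustrative assumptions rather than project
code.

import json

from flask import Flask
from flask_sock import Sock

app = Flask(__name__)
sock = Sock(app)


@sock.route("/api/ask")
def ask(ws):
    # flask-sock hands the handler a connection object; receive() blocks
    # until the client sends a frame, send() pushes one back.
    while True:
        data = json.loads(ws.receive())
        question = data.get("question", "")

        # The real handler emits a "sources" frame, then "answer_chunk"
        # frames, then a final "history" frame; this sketch only echoes.
        ws.send(json.dumps({"type": "answer_chunk", "data": question}))
        ws.send(json.dumps({"type": "history", "data": None}))


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)
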
diff --git a/pyproject.toml b/pyproject.toml
index 49fde4f..d0c7537 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
"beautifulsoup4>=4.13.4",
"flask>=3.1.1",
"flask-cors>=6.0.1",
+ "flask-sock>=0.7.0",
]
readme = "README.md"
requires-python = ">= 3.8"
diff --git a/requirements-dev.lock b/requirements-dev.lock
index c1671dc..7f07737 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -55,9 +55,12 @@ filelock==3.18.0
# via transformers
flask==3.1.1
# via flask-cors
+ # via flask-sock
# via umapyai
flask-cors==6.0.1
# via umapyai
+flask-sock==0.7.0
+ # via umapyai
flatbuffers==25.2.10
# via onnxruntime
fsspec==2025.7.0
@@ -73,6 +76,7 @@ grpcio==1.74.0
h11==0.16.0
# via httpcore
# via uvicorn
+ # via wsproto
hf-xet==1.1.5
# via huggingface-hub
httpcore==1.0.9
@@ -237,6 +241,8 @@ setuptools==80.9.0
# via torch
shellingham==1.5.4
# via typer
+simple-websocket==1.1.0
+ # via flask-sock
six==1.17.0
# via kubernetes
# via posthog
@@ -298,6 +304,8 @@ websockets==15.0.1
werkzeug==3.1.3
# via flask
# via flask-cors
+wsproto==1.2.0
+ # via simple-websocket
yapf==0.43.0
zipp==3.23.0
# via importlib-metadata
diff --git a/requirements.lock b/requirements.lock
index ef68198..ab8398b 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -55,9 +55,12 @@ filelock==3.18.0
# via transformers
flask==3.1.1
# via flask-cors
+ # via flask-sock
# via umapyai
flask-cors==6.0.1
# via umapyai
+flask-sock==0.7.0
+ # via umapyai
flatbuffers==25.2.10
# via onnxruntime
fsspec==2025.7.0
@@ -73,6 +76,7 @@ grpcio==1.74.0
h11==0.16.0
# via httpcore
# via uvicorn
+ # via wsproto
hf-xet==1.1.5
# via huggingface-hub
httpcore==1.0.9
@@ -234,6 +238,8 @@ setuptools==80.9.0
# via torch
shellingham==1.5.4
# via typer
+simple-websocket==1.1.0
+ # via flask-sock
six==1.17.0
# via kubernetes
# via posthog
@@ -295,5 +301,7 @@ websockets==15.0.1
werkzeug==3.1.3
# via flask
# via flask-cors
+wsproto==1.2.0
+ # via simple-websocket
zipp==3.23.0
# via importlib-metadata
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py
index 2aab5fe..348fd23 100644
--- a/src/umapyai/__init__.py
+++ b/src/umapyai/__init__.py
@@ -1,12 +1,14 @@
import os
import sys
+import json
import chromadb
from sentence_transformers import SentenceTransformer
import requests
from loguru import logger
from threading import Thread
-from flask import Flask, request, jsonify, send_file
+from flask import Flask, send_file
from flask_cors import CORS
+from flask_sock import Sock
from .constants import (ARTICLES_DIRECTORY, CHROMA_DIRECTORY, CHROMA_COLLECTION,
CHUNK_SIZE, EMBEDDING_MODEL, OLLAMA_MODEL, TOP_K,
OLLAMA_URL)
@@ -24,6 +26,8 @@ app = Flask(__name__)
CORS(app)
+socket = Sock(app)
+
def prompt(rag_context, user_query, is_first_turn=True):
if is_first_turn:
@@ -63,23 +67,23 @@ def start_flask(find_relevant_chunks, query_ollama):
def index():
return send_file("chat.html")
- @app.route("/api/ask", methods=["POST"])
- def api_ask():
- data = request.get_json()
- user_query = data.get("question", "")
- history_context = data.get("history", None)
- top_chunks = find_relevant_chunks(user_query)
- rag_context = "\n\n".join([c[0] for c in top_chunks])
- full_prompt = prompt(
- rag_context, user_query, is_first_turn=(history_context is None))
- answer, new_history_context = query_ollama(full_prompt, history_context)
- sources = ", ".join(sorted(set(meta['source'] for _, meta in top_chunks)))
-
- return jsonify({
- "answer": answer,
- "sources": sources,
- "history": new_history_context
- })
+ @socket.route("/api/ask")
+ def api_ask(webSocket):
+ while True:
+ data = json.loads(webSocket.receive())
+ user_query = data.get("question", "")
+ history_context = data.get("history", None)
+ top_chunks = find_relevant_chunks(user_query)
+ rag_context = "\n\n".join([chunk[0] for chunk in top_chunks])
+ full_prompt = prompt(
+ rag_context, user_query, is_first_turn=(history_context is None))
+ sources = ", ".join(
+ sorted(set(metadata['source'] for _, metadata in top_chunks)))
+
+ webSocket.send(json.dumps({"type": "sources", "data": sources}))
+
+ for ollama_response in query_ollama(full_prompt, history_context):
+ webSocket.send(json.dumps(ollama_response))
app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)
@@ -161,28 +165,34 @@ def main():
payload = {
"model": OLLAMA_MODEL,
"prompt": prompt,
- "stream": False,
+ "stream": True,
}
if context:
payload["context"] = context
try:
- response = requests.post(url, json=payload)
+ response = requests.post(url, json=payload, stream=True)
response.raise_for_status()
- json_response = response.json()
- answer = json_response.get('response', '').strip()
- new_context = json_response.get('context')
-
- return answer, new_context
+ for line in response.iter_lines():
+ if line:
+ json_response = json.loads(line)
+
+ if not json_response.get("done"):
+ yield {
+ "type": "answer_chunk",
+ "data": json_response.get("response", "")
+ }
+ else:
+ yield {"type": "history", "data": json_response.get("context")}
except Exception as error:
error_message = f"Error communicating with Ollama: {error}"
logger.error(error_message)
- return error_message, None
+ yield {"type": "error", "data": error_message}
flask_thread = Thread(
target=start_flask,
@@ -203,16 +213,25 @@ def main():
break
top_chunks = find_relevant_chunks(user_query)
- rag_context = "\n\n".join([c[0] for c in top_chunks])
+ rag_context = "\n\n".join([chunk[0] for chunk in top_chunks])
full_prompt = prompt(
rag_context, user_query, is_first_turn=(cli_history_context is None))
- answer, new_cli_history_context = query_ollama(full_prompt,
- cli_history_context)
- cli_history_context = new_cli_history_context
- print("\n", answer)
+ print("\n")
+
+ full_answer = ""
+
+ for ollama_response in query_ollama(full_prompt, cli_history_context):
+ if ollama_response["type"] == "answer_chunk":
+ chunk = ollama_response["data"]
+ full_answer += chunk
+
+ print(chunk, end="", flush=True)
+ elif ollama_response["type"] == "history":
+ cli_history_context = ollama_response["data"]
+
print(
- "\nSources:", ", ".join(
+ "\n\nSources:", ", ".join(
sorted(set(metadata['source'] for _, metadata in top_chunks))))
finally:
if started_ollama and ollama_process is not None:
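
With the server side in place, the streaming contract can be exercised outside the
browser. The following is a hypothetical smoke-test client, assuming the app from this
commit is running on localhost:5000 and using the asyncio websockets package (already
pinned in the lockfiles, though not necessarily intended for this purpose); the ask()
helper and the example question are illustrative only.

import asyncio
import json

import websockets


async def ask(question):
    async with websockets.connect("ws://localhost:5000/api/ask") as ws:
        # Same payload shape the chat.html client sends.
        await ws.send(json.dumps({"question": question, "history": None}))

        while True:
            message = json.loads(await ws.recv())

            if message["type"] == "sources":
                print("Sources:", message["data"])
            elif message["type"] == "answer_chunk":
                print(message["data"], end="", flush=True)
            elif message["type"] == "error":
                print("Error:", message["data"])
                break
            elif message["type"] == "history":
                # The final frame carries the Ollama context for follow-up turns.
                break


asyncio.run(ask("What is umapyai?"))
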
diff --git a/src/umapyai/chat.html b/src/umapyai/chat.html
index d5f50c2..d94491a 100644
--- a/src/umapyai/chat.html
+++ b/src/umapyai/chat.html
@@ -185,6 +185,24 @@
let sendButton = document.getElementById("send-button");
let chat = [];
let historyContext = null;
+ let currentAIMessage = null;
+ const webSocket = new WebSocket(`ws://${window.location.host}/api/ask`);
+
+ webSocket.onmessage = function (event) {
+ const response = JSON.parse(event.data);
+
+ if (response.type === "sources") {
+ currentAIMessage.sources = response.data;
+ } else if (response.type === "answer_chunk") {
+ currentAIMessage.text += response.data;
+ } else if (response.type === "history") {
+ historyContext = response.data;
+ } else if (response.type === "error") {
+ currentAIMessage.text += "\n\nError: " + response.data;
+ }
+
+ render();
+ };
const render = () => {
chatbox.innerHTML = "";
@@ -211,25 +229,17 @@
if (!query) return;
chat.push({ user: 1, text: query });
- render();
- prompt.value = "";
+ currentAIMessage = { user: 0, text: "", sources: "" };
- let response = await fetch("/api/ask", {
- method: "POST",
- headers: { "Content-Type": "application/json" },
- body: JSON.stringify({ question: query, history: historyContext }),
- });
- let responseData = await response.json();
+ chat.push(currentAIMessage);
+ render();
- historyContext = responseData.history;
+ prompt.value = "";
- chat.push({
- user: 0,
- text: responseData.answer,
- sources: responseData.sources,
- });
- render();
+ webSocket.send(
+ JSON.stringify({ question: query, history: historyContext }),
+ );
};
prompt.addEventListener("keydown", (event) => {