diff options
| author | Fuwn <[email protected]> | 2025-07-29 11:39:41 +0200 |
|---|---|---|
| committer | Fuwn <[email protected]> | 2025-07-29 11:39:41 +0200 |
| commit | 3d72b0f85f7a14c675d543e614f8a9c617d550eb (patch) | |
| tree | 9d50aa12b6b0a847899bb0a8a96090d53ef68436 | |
| parent | feat(umapyai): Improve system prompt self definition (diff) | |
| download | umapyai-3d72b0f85f7a14c675d543e614f8a9c617d550eb.tar.xz umapyai-3d72b0f85f7a14c675d543e614f8a9c617d550eb.zip | |
feat(umapyai): Response streaming
| -rw-r--r-- | pyproject.toml | 1 | ||||
| -rw-r--r-- | requirements-dev.lock | 8 | ||||
| -rw-r--r-- | requirements.lock | 8 | ||||
| -rw-r--r-- | src/umapyai/__init__.py | 83 | ||||
| -rw-r--r-- | src/umapyai/chat.html | 40 |
5 files changed, 93 insertions, 47 deletions
diff --git a/pyproject.toml b/pyproject.toml index 49fde4f..d0c7537 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "beautifulsoup4>=4.13.4", "flask>=3.1.1", "flask-cors>=6.0.1", + "flask-sock>=0.7.0", ] readme = "README.md" requires-python = ">= 3.8" diff --git a/requirements-dev.lock b/requirements-dev.lock index c1671dc..7f07737 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -55,9 +55,12 @@ filelock==3.18.0 # via transformers flask==3.1.1 # via flask-cors + # via flask-sock # via umapyai flask-cors==6.0.1 # via umapyai +flask-sock==0.7.0 + # via umapyai flatbuffers==25.2.10 # via onnxruntime fsspec==2025.7.0 @@ -73,6 +76,7 @@ grpcio==1.74.0 h11==0.16.0 # via httpcore # via uvicorn + # via wsproto hf-xet==1.1.5 # via huggingface-hub httpcore==1.0.9 @@ -237,6 +241,8 @@ setuptools==80.9.0 # via torch shellingham==1.5.4 # via typer +simple-websocket==1.1.0 + # via flask-sock six==1.17.0 # via kubernetes # via posthog @@ -298,6 +304,8 @@ websockets==15.0.1 werkzeug==3.1.3 # via flask # via flask-cors +wsproto==1.2.0 + # via simple-websocket yapf==0.43.0 zipp==3.23.0 # via importlib-metadata diff --git a/requirements.lock b/requirements.lock index ef68198..ab8398b 100644 --- a/requirements.lock +++ b/requirements.lock @@ -55,9 +55,12 @@ filelock==3.18.0 # via transformers flask==3.1.1 # via flask-cors + # via flask-sock # via umapyai flask-cors==6.0.1 # via umapyai +flask-sock==0.7.0 + # via umapyai flatbuffers==25.2.10 # via onnxruntime fsspec==2025.7.0 @@ -73,6 +76,7 @@ grpcio==1.74.0 h11==0.16.0 # via httpcore # via uvicorn + # via wsproto hf-xet==1.1.5 # via huggingface-hub httpcore==1.0.9 @@ -234,6 +238,8 @@ setuptools==80.9.0 # via torch shellingham==1.5.4 # via typer +simple-websocket==1.1.0 + # via flask-sock six==1.17.0 # via kubernetes # via posthog @@ -295,5 +301,7 @@ websockets==15.0.1 werkzeug==3.1.3 # via flask # via flask-cors +wsproto==1.2.0 + # via simple-websocket zipp==3.23.0 # via importlib-metadata diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py index 2aab5fe..348fd23 100644 --- a/src/umapyai/__init__.py +++ b/src/umapyai/__init__.py @@ -1,12 +1,14 @@ import os import sys +import json import chromadb from sentence_transformers import SentenceTransformer import requests from loguru import logger from threading import Thread -from flask import Flask, request, jsonify, send_file +from flask import Flask, send_file from flask_cors import CORS +from flask_sock import Sock from .constants import (ARTICLES_DIRECTORY, CHROMA_DIRECTORY, CHROMA_COLLECTION, CHUNK_SIZE, EMBEDDING_MODEL, OLLAMA_MODEL, TOP_K, OLLAMA_URL) @@ -24,6 +26,8 @@ app = Flask(__name__) CORS(app) +socket = Sock(app) + def prompt(rag_context, user_query, is_first_turn=True): if is_first_turn: @@ -63,23 +67,23 @@ def start_flask(find_relevant_chunks, query_ollama): def index(): return send_file("chat.html") - @app.route("/api/ask", methods=["POST"]) - def api_ask(): - data = request.get_json() - user_query = data.get("question", "") - history_context = data.get("history", None) - top_chunks = find_relevant_chunks(user_query) - rag_context = "\n\n".join([c[0] for c in top_chunks]) - full_prompt = prompt( - rag_context, user_query, is_first_turn=(history_context is None)) - answer, new_history_context = query_ollama(full_prompt, history_context) - sources = ", ".join(sorted(set(meta['source'] for _, meta in top_chunks))) - - return jsonify({ - "answer": answer, - "sources": sources, - "history": new_history_context - }) + @socket.route("/api/ask") + def api_ask(webSocket): + while True: + data = json.loads(webSocket.receive()) + user_query = data.get("question", "") + history_context = data.get("history", None) + top_chunks = find_relevant_chunks(user_query) + rag_context = "\n\n".join([chunk[0] for chunk in top_chunks]) + full_prompt = prompt( + rag_context, user_query, is_first_turn=(history_context is None)) + sources = ", ".join( + sorted(set(metadata['source'] for _, metadata in top_chunks))) + + webSocket.send(json.dumps({"type": "sources", "data": sources})) + + for ollama_response in query_ollama(full_prompt, history_context): + webSocket.send(json.dumps(ollama_response)) app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False) @@ -161,28 +165,34 @@ def main(): payload = { "model": OLLAMA_MODEL, "prompt": prompt, - "stream": False, + "stream": True, } if context: payload["context"] = context try: - response = requests.post(url, json=payload) + response = requests.post(url, json=payload, stream=True) response.raise_for_status() - json_response = response.json() - answer = json_response.get('response', '').strip() - new_context = json_response.get('context') - - return answer, new_context + for line in response.iter_lines(): + if line: + json_response = json.loads(line) + + if not json_response.get("done"): + yield { + "type": "answer_chunk", + "data": json_response.get("response", "") + } + else: + yield {"type": "history", "data": json_response.get("context")} except Exception as error: error_message = f"Error communicating with Ollama: {error}" logger.error(error_message) - return error_message, None + yield {"type": "error", "data": error_message} flask_thread = Thread( target=start_flask, @@ -203,16 +213,25 @@ def main(): break top_chunks = find_relevant_chunks(user_query) - rag_context = "\n\n".join([c[0] for c in top_chunks]) + rag_context = "\n\n".join([chunk[0] for chunk in top_chunks]) full_prompt = prompt( rag_context, user_query, is_first_turn=(cli_history_context is None)) - answer, new_cli_history_context = query_ollama(full_prompt, - cli_history_context) - cli_history_context = new_cli_history_context - print("\n", answer) + print("\n") + + full_answer = "" + + for ollama_response in query_ollama(full_prompt, cli_history_context): + if ollama_response["type"] == "answer_chunk": + chunk = ollama_response["data"] + full_answer += chunk + + print(chunk, end="", flush=True) + elif ollama_response["type"] == "history": + cli_history_context = ollama_response["data"] + print( - "\nSources:", ", ".join( + "\n\nSources:", ", ".join( sorted(set(metadata['source'] for _, metadata in top_chunks)))) finally: if started_ollama and ollama_process is not None: diff --git a/src/umapyai/chat.html b/src/umapyai/chat.html index d5f50c2..d94491a 100644 --- a/src/umapyai/chat.html +++ b/src/umapyai/chat.html @@ -185,6 +185,24 @@ let sendButton = document.getElementById("send-button"); let chat = []; let historyContext = null; + let currentAIMessage = null; + const webSocket = new WebSocket(`ws://${window.location.host}/api/ask`); + + webSocket.onmessage = function (event) { + const response = JSON.parse(event.data); + + if (response.type === "sources") { + currentAIMessage.sources = response.data; + } else if (response.type === "answer_chunk") { + currentAIMessage.text += response.data; + } else if (response.type === "history") { + historyContext = response.data; + } else if (response.type === "error") { + currentAIMessage.text += "\n\nError: " + response.data; + } + + render(); + }; const render = () => { chatbox.innerHTML = ""; @@ -211,25 +229,17 @@ if (!query) return; chat.push({ user: 1, text: query }); - render(); - prompt.value = ""; + currentAIMessage = { user: 0, text: "", sources: "" }; - let response = await fetch("/api/ask", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ question: query, history: historyContext }), - }); - let responseData = await response.json(); + chat.push(currentAIMessage); + render(); - historyContext = responseData.history; + prompt.value = ""; - chat.push({ - user: 0, - text: responseData.answer, - sources: responseData.sources, - }); - render(); + webSocket.send( + JSON.stringify({ question: query, history: historyContext }), + ); }; prompt.addEventListener("keydown", (event) => { |