diff options
| author | Fuwn <[email protected]> | 2025-07-30 10:14:44 +0200 |
|---|---|---|
| committer | Fuwn <[email protected]> | 2025-07-30 10:14:44 +0200 |
| commit | 3ef972ca814495211cf81ceb13c85dd0e05f0e4b (patch) | |
| tree | 0ae1b77d83450230035f3a0c6a084bb34265b40b | |
| parent | feat(umapyai): Swap default model (diff) | |
| download | umapyai-3ef972ca814495211cf81ceb13c85dd0e05f0e4b.tar.xz umapyai-3ef972ca814495211cf81ceb13c85dd0e05f0e4b.zip | |
feat(umapyai): Improve context handling
| -rw-r--r-- | src/umapyai/__init__.py | 77 | ||||
| -rw-r--r-- | src/umapyai/chat.html | 20 |
2 files changed, 55 insertions, 42 deletions
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py index a5b3c3d..e17b5ba 100644 --- a/src/umapyai/__init__.py +++ b/src/umapyai/__init__.py @@ -30,36 +30,31 @@ socket = Sock(app) CORS(app) -def prompt(rag_context, user_query, is_first_turn=True): - if is_first_turn: - system_prompt = ( - 'You are a friendly and expert "Uma Musume: Pretty Derby" build guide advisor. ' - 'Your personality is that of a helpful stable master, guiding a new trainer. ' - 'Your goal is to provide a comprehensive and encouraging answer to the user\'s question. ' - 'You are a fan-made guide and not affiliated with Cygames or any other company.' - ) - instruction = ( - "Carefully analyse the user's question and the provided context. " - "Synthesise the information from the context to form a coherent answer. " - "Connect different pieces of information and draw logical conclusions from the text. " - "If the context does not contain enough information to give a complete answer, " - "say so, but still provide any relevant information you did find.") - - return ( - f"{system_prompt}\n\n" - f"## Instruction\n{instruction}\n\n" - "## Context Provided\n" - f"--- START OF CONTEXT ---\n{rag_context}\n--- END OF CONTEXT ---\n\n" - f"## User's Question\n{user_query}\n\n" - "## Your Answer") - else: - return ( - "Here is some new information that might be relevant to your follow-up question. " - "Please synthesise it with our ongoing conversation to provide your answer.\n\n" - "## Additional Context\n" - f"--- START OF CONTEXT ---\n{rag_context}\n--- END OF CONTEXT ---\n\n" - f"## User's Question\n{user_query}\n\n" - "## Your Answer") +def prompt(rag_context, user_query, chat_history): + system_prompt = ( + 'You are a friendly and expert "Uma Musume: Pretty Derby" build guide advisor. ' + 'Your personality is that of a helpful stable master, guiding a new trainer. ' + 'Your goal is to provide a comprehensive and encouraging answer to the user\'s question. ' + 'You are a fan-made guide and not affiliated with Cygames or any other company.' + ) + instruction = ( + "Carefully analyse the user's question, the provided context, and the chat history. " + "Synthesise the information from the context and the chat history to form a coherent answer. " + "Connect different pieces of information and draw logical conclusions. " + "If the context does not contain enough information to give a complete answer, " + "say so, but still provide any relevant information you did find.") + history_string = "\n".join( + [f"{msg['role']}: {msg['content']}" for msg in chat_history]) + + return ( + f"{system_prompt}\n\n" + f"## Instruction\n{instruction}\n\n" + "## Context Provided\n" + f"--- START OF CONTEXT ---\n{rag_context}\n--- END OF CONTEXT ---\n\n" + "## Chat History\n" + f"--- START OF CHAT HISTORY ---\n{history_string}\n--- END OF CHAT HISTORY ---\n\n" + f"## User's Question\n{user_query}\n\n" + "## Your Answer") def start_flask(find_relevant_chunks, query_ollama): @@ -73,17 +68,17 @@ def start_flask(find_relevant_chunks, query_ollama): while True: data = json.loads(webSocket.receive()) user_query = data.get("question", "") - history_context = data.get("history", None) + ollama_history = data.get("history", None) + chat_history = data.get("chat_history", []) top_chunks = find_relevant_chunks(user_query) rag_context = "\n\n".join([chunk[0] for chunk in top_chunks]) - full_prompt = prompt( - rag_context, user_query, is_first_turn=(history_context is None)) + full_prompt = prompt(rag_context, user_query, chat_history) sources = ", ".join( sorted(set(metadata['source'] for _, metadata in top_chunks))) webSocket.send(json.dumps({"type": "sources", "data": sources})) - for ollama_response in query_ollama(full_prompt, history_context): + for ollama_response in query_ollama(full_prompt, ollama_history): webSocket.send(json.dumps(ollama_response)) app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False) @@ -233,7 +228,8 @@ def main(): "Ready! Ask your Uma Musume build questions (type 'exit' to quit).") logger.info("Web chat available at http://localhost:5000/") - cli_history_context = None + cli_ollama_history = None + cli_chat_history = [] while True: user_query = input("\n> ") @@ -243,22 +239,25 @@ def main(): top_chunks = find_relevant_chunks(user_query) rag_context = "\n\n".join([chunk[0] for chunk in top_chunks]) - full_prompt = prompt( - rag_context, user_query, is_first_turn=(cli_history_context is None)) + + cli_chat_history.append({"role": "user", "content": user_query}) + + full_prompt = prompt(rag_context, user_query, cli_chat_history) print("\n") full_answer = "" - for ollama_response in query_ollama(full_prompt, cli_history_context): + for ollama_response in query_ollama(full_prompt, cli_ollama_history): if ollama_response["type"] == "answer_chunk": chunk = ollama_response["data"] full_answer += chunk print(chunk, end="", flush=True) elif ollama_response["type"] == "history": - cli_history_context = ollama_response["data"] + cli_ollama_history = ollama_response["data"] + cli_chat_history.append({"role": "assistant", "content": full_answer}) print( "\n\nSources:", ", ".join( sorted(set(metadata['source'] for _, metadata in top_chunks)))) diff --git a/src/umapyai/chat.html b/src/umapyai/chat.html index d94491a..b339dab 100644 --- a/src/umapyai/chat.html +++ b/src/umapyai/chat.html @@ -184,7 +184,8 @@ let prompt = document.getElementById("prompt"); let sendButton = document.getElementById("send-button"); let chat = []; - let historyContext = null; + let history = null; + let ragHistory = []; let currentAIMessage = null; const webSocket = new WebSocket(`ws://${window.location.host}/api/ask`); @@ -196,7 +197,9 @@ } else if (response.type === "answer_chunk") { currentAIMessage.text += response.data; } else if (response.type === "history") { - historyContext = response.data; + history = response.data; + } else if (response.type === "rag_context") { + ragHistory.push(response.data); } else if (response.type === "error") { currentAIMessage.text += "\n\nError: " + response.data; } @@ -235,10 +238,21 @@ chat.push(currentAIMessage); render(); + const chatHistoryForPrompt = chat + .map((m) => ({ + role: m.user ? "user" : "assistant", + content: m.text, + })) + .slice(0, -1); + prompt.value = ""; webSocket.send( - JSON.stringify({ question: query, history: historyContext }), + JSON.stringify({ + question: query, + history: history, + chat_history: chatHistoryForPrompt, + }), ); }; |