about summary refs log tree commit diff
diff options
context:
space:
mode:
author Fuwn <[email protected]> 2025-07-30 10:14:44 +0200
committer Fuwn <[email protected]> 2025-07-30 10:14:44 +0200
commit 3ef972ca814495211cf81ceb13c85dd0e05f0e4b (patch)
tree 0ae1b77d83450230035f3a0c6a084bb34265b40b
parent feat(umapyai): Swap default model (diff)
download umapyai-3ef972ca814495211cf81ceb13c85dd0e05f0e4b.tar.xz
umapyai-3ef972ca814495211cf81ceb13c85dd0e05f0e4b.zip
feat(umapyai): Improve context handling
-rw-r--r-- src/umapyai/__init__.py 77
-rw-r--r-- src/umapyai/chat.html 20
2 files changed, 55 insertions, 42 deletions
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py
index a5b3c3d..e17b5ba 100644
--- a/src/umapyai/__init__.py
+++ b/src/umapyai/__init__.py
@@ -30,36 +30,31 @@ socket = Sock(app)
CORS(app)
-def prompt(rag_context, user_query, is_first_turn=True):
- if is_first_turn:
- system_prompt = (
- 'You are a friendly and expert "Uma Musume: Pretty Derby" build guide advisor. '
- 'Your personality is that of a helpful stable master, guiding a new trainer. '
- 'Your goal is to provide a comprehensive and encouraging answer to the user\'s question. '
- 'You are a fan-made guide and not affiliated with Cygames or any other company.'
- )
- instruction = (
- "Carefully analyse the user's question and the provided context. "
- "Synthesise the information from the context to form a coherent answer. "
- "Connect different pieces of information and draw logical conclusions from the text. "
- "If the context does not contain enough information to give a complete answer, "
- "say so, but still provide any relevant information you did find.")
-
- return (
- f"{system_prompt}\n\n"
- f"## Instruction\n{instruction}\n\n"
- "## Context Provided\n"
- f"--- START OF CONTEXT ---\n{rag_context}\n--- END OF CONTEXT ---\n\n"
- f"## User's Question\n{user_query}\n\n"
- "## Your Answer")
- else:
- return (
- "Here is some new information that might be relevant to your follow-up question. "
- "Please synthesise it with our ongoing conversation to provide your answer.\n\n"
- "## Additional Context\n"
- f"--- START OF CONTEXT ---\n{rag_context}\n--- END OF CONTEXT ---\n\n"
- f"## User's Question\n{user_query}\n\n"
- "## Your Answer")
+def prompt(rag_context, user_query, chat_history):
+ system_prompt = (
+ 'You are a friendly and expert "Uma Musume: Pretty Derby" build guide advisor. '
+ 'Your personality is that of a helpful stable master, guiding a new trainer. '
+ 'Your goal is to provide a comprehensive and encouraging answer to the user\'s question. '
+ 'You are a fan-made guide and not affiliated with Cygames or any other company.'
+ )
+ instruction = (
+ "Carefully analyse the user's question, the provided context, and the chat history. "
+ "Synthesise the information from the context and the chat history to form a coherent answer. "
+ "Connect different pieces of information and draw logical conclusions. "
+ "If the context does not contain enough information to give a complete answer, "
+ "say so, but still provide any relevant information you did find.")
+ history_string = "\n".join(
+ [f"{msg['role']}: {msg['content']}" for msg in chat_history])
+
+ return (
+ f"{system_prompt}\n\n"
+ f"## Instruction\n{instruction}\n\n"
+ "## Context Provided\n"
+ f"--- START OF CONTEXT ---\n{rag_context}\n--- END OF CONTEXT ---\n\n"
+ "## Chat History\n"
+ f"--- START OF CHAT HISTORY ---\n{history_string}\n--- END OF CHAT HISTORY ---\n\n"
+ f"## User's Question\n{user_query}\n\n"
+ "## Your Answer")
def start_flask(find_relevant_chunks, query_ollama):
@@ -73,17 +68,17 @@ def start_flask(find_relevant_chunks, query_ollama):
while True:
data = json.loads(webSocket.receive())
user_query = data.get("question", "")
- history_context = data.get("history", None)
+ ollama_history = data.get("history", None)
+ chat_history = data.get("chat_history", [])
top_chunks = find_relevant_chunks(user_query)
rag_context = "\n\n".join([chunk[0] for chunk in top_chunks])
- full_prompt = prompt(
- rag_context, user_query, is_first_turn=(history_context is None))
+ full_prompt = prompt(rag_context, user_query, chat_history)
sources = ", ".join(
sorted(set(metadata['source'] for _, metadata in top_chunks)))
webSocket.send(json.dumps({"type": "sources", "data": sources}))
- for ollama_response in query_ollama(full_prompt, history_context):
+ for ollama_response in query_ollama(full_prompt, ollama_history):
webSocket.send(json.dumps(ollama_response))
app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)
@@ -233,7 +228,8 @@ def main():
"Ready! Ask your Uma Musume build questions (type 'exit' to quit).")
logger.info("Web chat available at http://localhost:5000/")
- cli_history_context = None
+ cli_ollama_history = None
+ cli_chat_history = []
while True:
user_query = input("\n> ")
@@ -243,22 +239,25 @@ def main():
top_chunks = find_relevant_chunks(user_query)
rag_context = "\n\n".join([chunk[0] for chunk in top_chunks])
- full_prompt = prompt(
- rag_context, user_query, is_first_turn=(cli_history_context is None))
+
+ cli_chat_history.append({"role": "user", "content": user_query})
+
+ full_prompt = prompt(rag_context, user_query, cli_chat_history)
print("\n")
full_answer = ""
- for ollama_response in query_ollama(full_prompt, cli_history_context):
+ for ollama_response in query_ollama(full_prompt, cli_ollama_history):
if ollama_response["type"] == "answer_chunk":
chunk = ollama_response["data"]
full_answer += chunk
print(chunk, end="", flush=True)
elif ollama_response["type"] == "history":
- cli_history_context = ollama_response["data"]
+ cli_ollama_history = ollama_response["data"]
+ cli_chat_history.append({"role": "assistant", "content": full_answer})
print(
"\n\nSources:", ", ".join(
sorted(set(metadata['source'] for _, metadata in top_chunks))))
diff --git a/src/umapyai/chat.html b/src/umapyai/chat.html
index d94491a..b339dab 100644
--- a/src/umapyai/chat.html
+++ b/src/umapyai/chat.html
@@ -184,7 +184,8 @@
let prompt = document.getElementById("prompt");
let sendButton = document.getElementById("send-button");
let chat = [];
- let historyContext = null;
+ let history = null;
+ let ragHistory = [];
let currentAIMessage = null;
const webSocket = new WebSocket(`ws://${window.location.host}/api/ask`);
@@ -196,7 +197,9 @@
} else if (response.type === "answer_chunk") {
currentAIMessage.text += response.data;
} else if (response.type === "history") {
- historyContext = response.data;
+ history = response.data;
+ } else if (response.type === "rag_context") {
+ ragHistory.push(response.data);
} else if (response.type === "error") {
currentAIMessage.text += "\n\nError: " + response.data;
}
@@ -235,10 +238,21 @@
chat.push(currentAIMessage);
render();
+ const chatHistoryForPrompt = chat
+ .map((m) => ({
+ role: m.user ? "user" : "assistant",
+ content: m.text,
+ }))
+ .slice(0, -1);
+
prompt.value = "";
webSocket.send(
- JSON.stringify({ question: query, history: historyContext }),
+ JSON.stringify({
+ question: query,
+ history: history,
+ chat_history: chatHistoryForPrompt,
+ }),
);
};