aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorFuwn <[email protected]>2025-08-06 23:07:24 +0200
committerFuwn <[email protected]>2025-08-06 23:07:24 +0200
commitd620731d81586b4298976bd45e088d478aec0d3b (patch)
treec6b8ea00ebc23a18a239d9a7a216255963835080 /src
parentdocs: Add umapyoi.net API documentation (diff)
downloadumapyai-d620731d81586b4298976bd45e088d478aec0d3b.tar.xz
umapyai-d620731d81586b4298976bd45e088d478aec0d3b.zip
feat(umapyai): Add tools to model context
Diffstat (limited to 'src')
-rw-r--r--src/umapyai/__init__.py55
-rw-r--r--src/umapyai/tools.py295
2 files changed, 342 insertions, 8 deletions
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py
index b157516..8618862 100644
--- a/src/umapyai/__init__.py
+++ b/src/umapyai/__init__.py
@@ -14,6 +14,7 @@ from .constants import (ARTICLES_DIRECTORY, CHROMA_DIRECTORY, CHROMA_COLLECTION,
from .ollama_server import start_ollama_server, is_ollama_live, ensure_model_pulled, kill_ollama
from collections import defaultdict
from .language import clean_for_match, get_query_phrases
+from .tools import tools, call_tool
logger.remove()
logger.add(
@@ -34,7 +35,9 @@ def prompt(rag_context, user_query, chat_history):
'You are a friendly and expert "Uma Musume: Pretty Derby" build guide advisor. '
'Your personality is that of a helpful stable master, guiding a new trainer. You are not anyone in particular. '
'Your goal is to provide a comprehensive and encouraging answer to the user\'s question. '
- 'You are a fan-made guide and are not affiliated with Cygames or any other entity.'
+ 'You are a fan-made guide and are not affiliated with Cygames or any other entity.\n\n'
+ 'IMPORTANT: If the user asks a question that could be answered by a tool (API), ALWAYS call the tool, even if you believe you know the answer from context or memory. '
+ 'NEVER rely solely on the provided documentation or context for character data, builds, live news, or support card info—always call the relevant tool endpoint if one exists.'
)
instruction = (
"Carefully analyse the user's question, the provided context, and the chat history. "
@@ -185,13 +188,49 @@ def main():
def query_ollama(prompt, context=None):
try:
- for chunk in ollama.generate(
- model=OLLAMA_MODEL, prompt=prompt, stream=True, context=context):
- if not chunk.get("done"):
- yield {"type": "answer_chunk", "data": chunk.get("response", "")}
- else:
- yield {"type": "history", "data": chunk.get("context")}
-
+ messages = [{"role": "user", "content": prompt}]
+
+ if context:
+ messages = context + messages
+
+ while True:
+ tool_called = False
+
+ for chunk in ollama.chat(
+ model=OLLAMA_MODEL, messages=messages, stream=True, tools=tools):
+ message = chunk.get("message", {})
+
+ if "tool_calls" in message:
+ for tool_call in message["tool_calls"]:
+ tool_result = call_tool(tool_call)
+
+ messages.append({
+ "role": "assistant",
+ "tool_call_id": tool_call.get("id", "tool_call_1"),
+ "name": tool_call.get("name", "unknown_tool"),
+ "content": None,
+ "tool_calls": [tool_call],
+ })
+ messages.append({
+ "role": "tool",
+ "tool_call_id": tool_call.get("id", "tool_call_1"),
+ "content": str(tool_result),
+ })
+
+ tool_called = True
+
+ break
+ elif "content" in message and message["content"]:
+ yield {"type": "answer_chunk", "data": message["content"]}
+
+ if chunk.get("done"):
+ yield {
+ "type": "history",
+ "data": messages + ([message] if message else [])
+ }
+
+ if not tool_called:
+ break
except Exception as error:
error_message = f"Error communicating with Ollama: {error}"
diff --git a/src/umapyai/tools.py b/src/umapyai/tools.py
new file mode 100644
index 0000000..36ed85b
--- /dev/null
+++ b/src/umapyai/tools.py
@@ -0,0 +1,295 @@
+import json
+import requests
+import inspect
+
+
+def get_character_info(id: int):
+ "Gets a character's information by their ID."
+
+ return requests.get(f"https://umapyoi.net/api/v1/character/{id}").json()
+
+
+def get_character_list():
+ "Gets a list of all characters."
+
+ return requests.get("https://umapyoi.net/api/v1/character/list").json()
+
+
+def get_current_gacha_banners():
+ "Gets the current gacha banners."
+
+ return requests.get("https://umapyoi.net/api/v1/gacha/current").json()
+
+
+def get_latest_news(count: int):
+ "Gets the latest news."
+
+ return requests.get(f"https://umapyoi.net/api/v1/news/latest/{count}").json()
+
+
+def get_support_card_info(id: int):
+ "Gets a support card's information by its ID."
+
+ return requests.get(f"https://umapyoi.net/api/v1/support/{id}").json()
+
+
+def get_support_card_list():
+ "Gets a list of all support cards."
+
+ return requests.get("https://umapyoi.net/api/v1/support").json()
+
+
+def get_music_list():
+ "Gets a list of all music."
+
+ return requests.get("https://umapyoi.net/api/v1/music/min/albums").json()
+
+
+def get_voice_actor_info(id: int):
+ "Gets a voice actor's information by their ID."
+
+ return requests.get(f"https://umapyoi.net/api/v1/va/{id}").json()
+
+
+def get_outfit_info(id: int):
+ "Gets an outfit's information by its ID."
+
+ return requests.get(f"https://umapyoi.net/api/v1/outfit/{id}").json()
+
+
+def get_vpn_list():
+ "Gets a list of all VPNs."
+
+ return requests.get("https://umapyoi.net/api/v1/vpn/all").json()
+
+
+def get_gacha_info(id: int):
+ "Gets a gacha banner's information by its ID."
+
+ return requests.get(f"https://umapyoi.net/api/v1/gacha/{id}").json()
+
+
+def get_news_info(id: int):
+ "Gets a news post's information by its ID."
+
+ return requests.get(f"https://umapyoi.net/api/v1/news/{id}").json()
+
+
+def get_music_info(id: int):
+ "Gets a music album's information by its ID."
+
+ return requests.get(f"https://umapyoi.net/api/v1/music/album/{id}").json()
+
+
+tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_character_info",
+ "description": "Gets a character's information by their ID.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "integer",
+ "description": "The character's ID."
+ }
+ },
+ "required": ["id"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_character_list",
+ "description": "Gets a list of all characters.",
+ "parameters": {},
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_gacha_banners",
+ "description": "Gets the current gacha banners.",
+ "parameters": {},
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_latest_news",
+ "description": "Gets the latest news.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "count": {
+ "type": "integer",
+ "description": "The number of news to get.",
+ }
+ },
+ "required": ["count"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_support_card_info",
+ "description": "Gets a support card's information by its ID.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "integer",
+ "description": "The support card's ID."
+ }
+ },
+ "required": ["id"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_support_card_list",
+ "description": "Gets a list of all support cards.",
+ "parameters": {},
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_music_list",
+ "description": "Gets a list of all music.",
+ "parameters": {},
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_voice_actor_info",
+ "description": "Gets a voice actor's information by their ID.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "integer",
+ "description": "The voice actor's ID."
+ }
+ },
+ "required": ["id"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_outfit_info",
+ "description": "Gets an outfit's information by its ID.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "integer",
+ "description": "The outfit's ID."
+ }
+ },
+ "required": ["id"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_vpn_list",
+ "description": "Gets a list of all VPNs.",
+ "parameters": {},
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_gacha_info",
+ "description": "Gets a gacha banner's information by its ID.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "integer",
+ "description": "The gacha banner's ID."
+ }
+ },
+ "required": ["id"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_news_info",
+ "description": "Gets a news post's information by its ID.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "integer",
+ "description": "The news post's ID."
+ }
+ },
+ "required": ["id"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_music_info",
+ "description": "Gets a music album's information by its ID.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "integer",
+ "description": "The music album's ID."
+ }
+ },
+ "required": ["id"],
+ },
+ },
+ },
+]
+
+API_FUNCTIONS = {
+ "get_character_info": get_character_info,
+ "get_character_list": get_character_list,
+ "get_current_gacha_banners": get_current_gacha_banners,
+ "get_latest_news": get_latest_news,
+ "get_support_card_info": get_support_card_info,
+ "get_support_card_list": get_support_card_list,
+ "get_music_list": get_music_list,
+ "get_voice_actor_info": get_voice_actor_info,
+ "get_outfit_info": get_outfit_info,
+ "get_vpn_list": get_vpn_list,
+ "get_gacha_info": get_gacha_info,
+ "get_news_info": get_news_info,
+ "get_music_info": get_music_info,
+}
+
+
+def call_tool(tool_call):
+ name = tool_call["function"]["name"]
+ arguments = tool_call["function"]["arguments"]
+
+ if isinstance(arguments, str):
+ arguments = json.loads(arguments)
+
+ if name in API_FUNCTIONS:
+ function = API_FUNCTIONS[name]
+
+ if inspect.signature(function).parameters:
+ return function(**arguments)
+ else:
+ return function()
+
+ return {"error": "Unknown function"}