From d620731d81586b4298976bd45e088d478aec0d3b Mon Sep 17 00:00:00 2001 From: Fuwn Date: Wed, 6 Aug 2025 23:07:24 +0200 Subject: feat(umapyai): Add tools to model context --- src/umapyai/__init__.py | 55 +++++++-- src/umapyai/tools.py | 295 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 342 insertions(+), 8 deletions(-) create mode 100644 src/umapyai/tools.py (limited to 'src') diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py index b157516..8618862 100644 --- a/src/umapyai/__init__.py +++ b/src/umapyai/__init__.py @@ -14,6 +14,7 @@ from .constants import (ARTICLES_DIRECTORY, CHROMA_DIRECTORY, CHROMA_COLLECTION, from .ollama_server import start_ollama_server, is_ollama_live, ensure_model_pulled, kill_ollama from collections import defaultdict from .language import clean_for_match, get_query_phrases +from .tools import tools, call_tool logger.remove() logger.add( @@ -34,7 +35,9 @@ def prompt(rag_context, user_query, chat_history): 'You are a friendly and expert "Uma Musume: Pretty Derby" build guide advisor. ' 'Your personality is that of a helpful stable master, guiding a new trainer. You are not anyone in particular. ' 'Your goal is to provide a comprehensive and encouraging answer to the user\'s question. ' - 'You are a fan-made guide and are not affiliated with Cygames or any other entity.' + 'You are a fan-made guide and are not affiliated with Cygames or any other entity.\n\n' + 'IMPORTANT: If the user asks a question that could be answered by a tool (API), ALWAYS call the tool, even if you believe you know the answer from context or memory. ' + 'NEVER rely solely on the provided documentation or context for character data, builds, live news, or support card info—always call the relevant tool endpoint if one exists.' ) instruction = ( "Carefully analyse the user's question, the provided context, and the chat history. " @@ -185,13 +188,49 @@ def main(): def query_ollama(prompt, context=None): try: - for chunk in ollama.generate( - model=OLLAMA_MODEL, prompt=prompt, stream=True, context=context): - if not chunk.get("done"): - yield {"type": "answer_chunk", "data": chunk.get("response", "")} - else: - yield {"type": "history", "data": chunk.get("context")} - + messages = [{"role": "user", "content": prompt}] + + if context: + messages = context + messages + + while True: + tool_called = False + + for chunk in ollama.chat( + model=OLLAMA_MODEL, messages=messages, stream=True, tools=tools): + message = chunk.get("message", {}) + + if "tool_calls" in message: + for tool_call in message["tool_calls"]: + tool_result = call_tool(tool_call) + + messages.append({ + "role": "assistant", + "tool_call_id": tool_call.get("id", "tool_call_1"), + "name": tool_call.get("name", "unknown_tool"), + "content": None, + "tool_calls": [tool_call], + }) + messages.append({ + "role": "tool", + "tool_call_id": tool_call.get("id", "tool_call_1"), + "content": str(tool_result), + }) + + tool_called = True + + break + elif "content" in message and message["content"]: + yield {"type": "answer_chunk", "data": message["content"]} + + if chunk.get("done"): + yield { + "type": "history", + "data": messages + ([message] if message else []) + } + + if not tool_called: + break except Exception as error: error_message = f"Error communicating with Ollama: {error}" diff --git a/src/umapyai/tools.py b/src/umapyai/tools.py new file mode 100644 index 0000000..36ed85b --- /dev/null +++ b/src/umapyai/tools.py @@ -0,0 +1,295 @@ +import json +import requests +import inspect + + +def get_character_info(id: int): + "Gets a character's information by their ID." + + return requests.get(f"https://umapyoi.net/api/v1/character/{id}").json() + + +def get_character_list(): + "Gets a list of all characters." + + return requests.get("https://umapyoi.net/api/v1/character/list").json() + + +def get_current_gacha_banners(): + "Gets the current gacha banners." + + return requests.get("https://umapyoi.net/api/v1/gacha/current").json() + + +def get_latest_news(count: int): + "Gets the latest news." + + return requests.get(f"https://umapyoi.net/api/v1/news/latest/{count}").json() + + +def get_support_card_info(id: int): + "Gets a support card's information by its ID." + + return requests.get(f"https://umapyoi.net/api/v1/support/{id}").json() + + +def get_support_card_list(): + "Gets a list of all support cards." + + return requests.get("https://umapyoi.net/api/v1/support").json() + + +def get_music_list(): + "Gets a list of all music." + + return requests.get("https://umapyoi.net/api/v1/music/min/albums").json() + + +def get_voice_actor_info(id: int): + "Gets a voice actor's information by their ID." + + return requests.get(f"https://umapyoi.net/api/v1/va/{id}").json() + + +def get_outfit_info(id: int): + "Gets an outfit's information by its ID." + + return requests.get(f"https://umapyoi.net/api/v1/outfit/{id}").json() + + +def get_vpn_list(): + "Gets a list of all VPNs." + + return requests.get("https://umapyoi.net/api/v1/vpn/all").json() + + +def get_gacha_info(id: int): + "Gets a gacha banner's information by its ID." + + return requests.get(f"https://umapyoi.net/api/v1/gacha/{id}").json() + + +def get_news_info(id: int): + "Gets a news post's information by its ID." + + return requests.get(f"https://umapyoi.net/api/v1/news/{id}").json() + + +def get_music_info(id: int): + "Gets a music album's information by its ID." + + return requests.get(f"https://umapyoi.net/api/v1/music/album/{id}").json() + + +tools = [ + { + "type": "function", + "function": { + "name": "get_character_info", + "description": "Gets a character's information by their ID.", + "parameters": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "The character's ID." + } + }, + "required": ["id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_character_list", + "description": "Gets a list of all characters.", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "name": "get_current_gacha_banners", + "description": "Gets the current gacha banners.", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "name": "get_latest_news", + "description": "Gets the latest news.", + "parameters": { + "type": "object", + "properties": { + "count": { + "type": "integer", + "description": "The number of news to get.", + } + }, + "required": ["count"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_support_card_info", + "description": "Gets a support card's information by its ID.", + "parameters": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "The support card's ID." + } + }, + "required": ["id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_support_card_list", + "description": "Gets a list of all support cards.", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "name": "get_music_list", + "description": "Gets a list of all music.", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "name": "get_voice_actor_info", + "description": "Gets a voice actor's information by their ID.", + "parameters": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "The voice actor's ID." + } + }, + "required": ["id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_outfit_info", + "description": "Gets an outfit's information by its ID.", + "parameters": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "The outfit's ID." + } + }, + "required": ["id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_vpn_list", + "description": "Gets a list of all VPNs.", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "name": "get_gacha_info", + "description": "Gets a gacha banner's information by its ID.", + "parameters": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "The gacha banner's ID." + } + }, + "required": ["id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_news_info", + "description": "Gets a news post's information by its ID.", + "parameters": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "The news post's ID." + } + }, + "required": ["id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_music_info", + "description": "Gets a music album's information by its ID.", + "parameters": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "description": "The music album's ID." + } + }, + "required": ["id"], + }, + }, + }, +] + +API_FUNCTIONS = { + "get_character_info": get_character_info, + "get_character_list": get_character_list, + "get_current_gacha_banners": get_current_gacha_banners, + "get_latest_news": get_latest_news, + "get_support_card_info": get_support_card_info, + "get_support_card_list": get_support_card_list, + "get_music_list": get_music_list, + "get_voice_actor_info": get_voice_actor_info, + "get_outfit_info": get_outfit_info, + "get_vpn_list": get_vpn_list, + "get_gacha_info": get_gacha_info, + "get_news_info": get_news_info, + "get_music_info": get_music_info, +} + + +def call_tool(tool_call): + name = tool_call["function"]["name"] + arguments = tool_call["function"]["arguments"] + + if isinstance(arguments, str): + arguments = json.loads(arguments) + + if name in API_FUNCTIONS: + function = API_FUNCTIONS[name] + + if inspect.signature(function).parameters: + return function(**arguments) + else: + return function() + + return {"error": "Unknown function"} -- cgit v1.2.3