aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorFuwn <[email protected]>2025-07-27 20:26:02 +0200
committerFuwn <[email protected]>2025-07-27 20:26:02 +0200
commit3f3f6c56981116e7982461b86025da0e278bf8d4 (patch)
treea666f0ae6d609e20e9a34408d55d465893fc383a /src
downloadumapyai-3f3f6c56981116e7982461b86025da0e278bf8d4.tar.xz
umapyai-3f3f6c56981116e7982461b86025da0e278bf8d4.zip
feat: Initial commit
Diffstat (limited to 'src')
-rw-r--r--src/umapyai/__init__.py187
-rw-r--r--src/umapyai/__main__.py4
-rw-r--r--src/umapyai/constants.py8
3 files changed, 199 insertions, 0 deletions
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py
new file mode 100644
index 0000000..4898488
--- /dev/null
+++ b/src/umapyai/__init__.py
@@ -0,0 +1,187 @@
+import os
+import sys
+import time
+import chromadb
+from sentence_transformers import SentenceTransformer
+import requests
+import subprocess
+import psutil
+from .constants import ARTICLES_DIRECTORY, CHROMA_DIRECTORY, CHROMA_COLLECTION, CHUNK_SIZE, EMBEDDING_MODEL, OLLAMA_MODEL, TOP_K, OLLAMA_URL
+
+
+def is_ollama_live():
+ try:
+ response = requests.get(f"{OLLAMA_URL}/api/tags", timeout=2)
+
+ return response.status_code == 200
+ except Exception:
+ return False
+
+
+def start_ollama_server():
+ print("Starting Ollama server with OLLAMA_ORIGINS='*' ...")
+
+ environment = os.environ.copy()
+ environment["OLLAMA_ORIGINS"] = "*"
+ process = subprocess.Popen(["ollama", "serve"],
+ env=environment,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True)
+
+ for _ in range(30):
+ if is_ollama_live():
+ print("Ollama is now live.")
+
+ return process
+
+ time.sleep(1)
+
+ print("ERROR: Ollama server did not start after 30 seconds.")
+ process.terminate()
+ sys.exit(1)
+
+
+def kill_ollama(process):
+ print("Killing Ollama ...")
+
+ try:
+ parent_process = psutil.Process(process.pid)
+
+ for child_process in parent_process.children(recursive=True):
+ child_process.terminate()
+
+ parent_process.terminate()
+ except Exception as error:
+ print("Error killing Ollama:", error)
+
+
+def ensure_model_pulled(model):
+ try:
+ tags = requests.get(f"{OLLAMA_URL}/api/tags").json().get("models", [])
+
+ if not any(model in m.get("name", "") for m in tags):
+ print(f"Pulling model '{model}' ...")
+ subprocess.run(["ollama", "pull", model], check=True)
+ else:
+ print(f"Model '{model}' already pulled.")
+ except Exception as e:
+ print("Couldn't check/pull Ollama model:", e)
+ print("Proceeding anyway ...")
+
+
+def main():
+ ollama_process = None
+ started_ollama = False
+
+ try:
+ if not is_ollama_live():
+ ollama_process = start_ollama_server()
+ started_ollama = True
+ else:
+ print("Ollama is already running.")
+
+ ensure_model_pulled(OLLAMA_MODEL)
+
+ print("Chunking articles ...")
+
+ chunks = []
+
+ for file_name in os.listdir(ARTICLES_DIRECTORY):
+ if not file_name.endswith(".txt"):
+ continue
+
+ with open(
+ os.path.join(ARTICLES_DIRECTORY, file_name),
+ encoding="utf-8") as file:
+ words = file.read().split()
+
+ for i in range(0, len(words), CHUNK_SIZE):
+ chunk = " ".join(words[i:i + CHUNK_SIZE])
+
+ if chunk.strip():
+ chunks.append({"source": file_name, "chunk": chunk})
+
+ print(f"Total chunks: {len(chunks)}")
+ print("Generating embeddings ...")
+
+ model = SentenceTransformer(EMBEDDING_MODEL)
+
+ for chunk in chunks:
+ chunk["embedding"] = model.encode(chunk["chunk"])
+
+ print("Storing embeddings in ChromaDB ...")
+
+ chroma_client = chromadb.PersistentClient(path=CHROMA_DIRECTORY)
+
+ if CHROMA_COLLECTION in [
+ collection.name for collection in chroma_client.list_collections()
+ ]:
+ print("Collection exists, deleting and recreating for fresh import ...")
+
+ chroma_client.delete_collection(CHROMA_COLLECTION)
+
+ collection = chroma_client.get_or_create_collection(CHROMA_COLLECTION)
+
+ for i, chunk in enumerate(chunks):
+ collection.add(
+ ids=[str(i)],
+ documents=[chunk["chunk"]],
+ embeddings=[chunk["embedding"].tolist()],
+ metadatas=[{
+ "source": chunk["source"]
+ }])
+
+ def find_relevant_chunks(query, top_k=TOP_K):
+ q_embed = model.encode(query)
+ results = collection.query(
+ query_embeddings=[q_embed.tolist()], n_results=top_k)
+ documents = results['documents'][0]
+ metadatas = results['metadatas'][0]
+
+ return [(document, metadata)
+ for document, metadata in zip(documents, metadatas)]
+
+ def query_ollama(prompt):
+ url = f"{OLLAMA_URL}/api/generate"
+ payload = {
+ "model": OLLAMA_MODEL,
+ "prompt": prompt,
+ "stream": False,
+ }
+
+ try:
+ response = requests.post(url, json=payload)
+
+ response.raise_for_status()
+
+ return response.json().get('response', '').strip()
+ except Exception as error:
+ return f"Error communicating with Ollama: {error}"
+
+ print("\nReady! Ask your Uma Musume build questions (type 'exit' to quit):")
+
+ while True:
+ user_query = input("\n> ")
+
+ if user_query.strip().lower() == "exit":
+ break
+
+ top_chunks = find_relevant_chunks(user_query)
+ context = "\n\n".join([c[0] for c in top_chunks])
+ full_prompt = (
+ "You are an expert Uma Musume: Pretty Derby build guide advisor.\n"
+ "Answer the user's question using ONLY the following context. "
+ "If the answer isn't in the context, say you don't know.\n\n"
+ f"Context:\n{context}\n\n"
+ f"Question: {user_query}\nAnswer:")
+ answer = query_ollama(full_prompt)
+
+ print("\n", answer)
+ print(
+ "\nSources:", ", ".join(
+ sorted(set(metadata['source'] for _, metadata in top_chunks))))
+ finally:
+ if started_ollama and ollama_process is not None:
+ kill_ollama(ollama_process)
+ print("Ollama server stopped.")
diff --git a/src/umapyai/__main__.py b/src/umapyai/__main__.py
new file mode 100644
index 0000000..851db54
--- /dev/null
+++ b/src/umapyai/__main__.py
@@ -0,0 +1,4 @@
+import umapyai
+import sys
+
+sys.exit(umapyai.main())
diff --git a/src/umapyai/constants.py b/src/umapyai/constants.py
new file mode 100644
index 0000000..9add8ce
--- /dev/null
+++ b/src/umapyai/constants.py
@@ -0,0 +1,8 @@
+ARTICLES_DIRECTORY = "uma_articles_clean"
+CHROMA_DIRECTORY = "./chromadb"
+CHROMA_COLLECTION = "uma_guides"
+CHUNK_SIZE = 350 # words
+EMBEDDING_MODEL = "all-MiniLM-L6-v2"
+OLLAMA_MODEL = "llama3.2"
+TOP_K = 4
+OLLAMA_URL = "http://localhost:11434"