diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/umapyai/__init__.py | 187 | ||||
| -rw-r--r-- | src/umapyai/__main__.py | 4 | ||||
| -rw-r--r-- | src/umapyai/constants.py | 8 |
3 files changed, 199 insertions, 0 deletions
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py new file mode 100644 index 0000000..4898488 --- /dev/null +++ b/src/umapyai/__init__.py @@ -0,0 +1,187 @@ +import os +import sys +import time +import chromadb +from sentence_transformers import SentenceTransformer +import requests +import subprocess +import psutil +from .constants import ARTICLES_DIRECTORY, CHROMA_DIRECTORY, CHROMA_COLLECTION, CHUNK_SIZE, EMBEDDING_MODEL, OLLAMA_MODEL, TOP_K, OLLAMA_URL + + +def is_ollama_live(): + try: + response = requests.get(f"{OLLAMA_URL}/api/tags", timeout=2) + + return response.status_code == 200 + except Exception: + return False + + +def start_ollama_server(): + print("Starting Ollama server with OLLAMA_ORIGINS='*' ...") + + environment = os.environ.copy() + environment["OLLAMA_ORIGINS"] = "*" + process = subprocess.Popen(["ollama", "serve"], + env=environment, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True) + + for _ in range(30): + if is_ollama_live(): + print("Ollama is now live.") + + return process + + time.sleep(1) + + print("ERROR: Ollama server did not start after 30 seconds.") + process.terminate() + sys.exit(1) + + +def kill_ollama(process): + print("Killing Ollama ...") + + try: + parent_process = psutil.Process(process.pid) + + for child_process in parent_process.children(recursive=True): + child_process.terminate() + + parent_process.terminate() + except Exception as error: + print("Error killing Ollama:", error) + + +def ensure_model_pulled(model): + try: + tags = requests.get(f"{OLLAMA_URL}/api/tags").json().get("models", []) + + if not any(model in m.get("name", "") for m in tags): + print(f"Pulling model '{model}' ...") + subprocess.run(["ollama", "pull", model], check=True) + else: + print(f"Model '{model}' already pulled.") + except Exception as e: + print("Couldn't check/pull Ollama model:", e) + print("Proceeding anyway ...") + + +def main(): + ollama_process = None + started_ollama = False + + try: + if not is_ollama_live(): + ollama_process = start_ollama_server() + started_ollama = True + else: + print("Ollama is already running.") + + ensure_model_pulled(OLLAMA_MODEL) + + print("Chunking articles ...") + + chunks = [] + + for file_name in os.listdir(ARTICLES_DIRECTORY): + if not file_name.endswith(".txt"): + continue + + with open( + os.path.join(ARTICLES_DIRECTORY, file_name), + encoding="utf-8") as file: + words = file.read().split() + + for i in range(0, len(words), CHUNK_SIZE): + chunk = " ".join(words[i:i + CHUNK_SIZE]) + + if chunk.strip(): + chunks.append({"source": file_name, "chunk": chunk}) + + print(f"Total chunks: {len(chunks)}") + print("Generating embeddings ...") + + model = SentenceTransformer(EMBEDDING_MODEL) + + for chunk in chunks: + chunk["embedding"] = model.encode(chunk["chunk"]) + + print("Storing embeddings in ChromaDB ...") + + chroma_client = chromadb.PersistentClient(path=CHROMA_DIRECTORY) + + if CHROMA_COLLECTION in [ + collection.name for collection in chroma_client.list_collections() + ]: + print("Collection exists, deleting and recreating for fresh import ...") + + chroma_client.delete_collection(CHROMA_COLLECTION) + + collection = chroma_client.get_or_create_collection(CHROMA_COLLECTION) + + for i, chunk in enumerate(chunks): + collection.add( + ids=[str(i)], + documents=[chunk["chunk"]], + embeddings=[chunk["embedding"].tolist()], + metadatas=[{ + "source": chunk["source"] + }]) + + def find_relevant_chunks(query, top_k=TOP_K): + q_embed = model.encode(query) + results = collection.query( + query_embeddings=[q_embed.tolist()], n_results=top_k) + documents = results['documents'][0] + metadatas = results['metadatas'][0] + + return [(document, metadata) + for document, metadata in zip(documents, metadatas)] + + def query_ollama(prompt): + url = f"{OLLAMA_URL}/api/generate" + payload = { + "model": OLLAMA_MODEL, + "prompt": prompt, + "stream": False, + } + + try: + response = requests.post(url, json=payload) + + response.raise_for_status() + + return response.json().get('response', '').strip() + except Exception as error: + return f"Error communicating with Ollama: {error}" + + print("\nReady! Ask your Uma Musume build questions (type 'exit' to quit):") + + while True: + user_query = input("\n> ") + + if user_query.strip().lower() == "exit": + break + + top_chunks = find_relevant_chunks(user_query) + context = "\n\n".join([c[0] for c in top_chunks]) + full_prompt = ( + "You are an expert Uma Musume: Pretty Derby build guide advisor.\n" + "Answer the user's question using ONLY the following context. " + "If the answer isn't in the context, say you don't know.\n\n" + f"Context:\n{context}\n\n" + f"Question: {user_query}\nAnswer:") + answer = query_ollama(full_prompt) + + print("\n", answer) + print( + "\nSources:", ", ".join( + sorted(set(metadata['source'] for _, metadata in top_chunks)))) + finally: + if started_ollama and ollama_process is not None: + kill_ollama(ollama_process) + print("Ollama server stopped.") diff --git a/src/umapyai/__main__.py b/src/umapyai/__main__.py new file mode 100644 index 0000000..851db54 --- /dev/null +++ b/src/umapyai/__main__.py @@ -0,0 +1,4 @@ +import umapyai +import sys + +sys.exit(umapyai.main()) diff --git a/src/umapyai/constants.py b/src/umapyai/constants.py new file mode 100644 index 0000000..9add8ce --- /dev/null +++ b/src/umapyai/constants.py @@ -0,0 +1,8 @@ +ARTICLES_DIRECTORY = "uma_articles_clean" +CHROMA_DIRECTORY = "./chromadb" +CHROMA_COLLECTION = "uma_guides" +CHUNK_SIZE = 350 # words +EMBEDDING_MODEL = "all-MiniLM-L6-v2" +OLLAMA_MODEL = "llama3.2" +TOP_K = 4 +OLLAMA_URL = "http://localhost:11434" |