aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFuwn <[email protected]>2025-07-27 20:43:33 +0200
committerFuwn <[email protected]>2025-07-27 20:43:33 +0200
commitb8e1e39463c61267aef680670fb62c421e4a27b5 (patch)
tree32552ba266cd0c5c3246ceea0b86e59b43d82c0d
parentfeat: Pretty logging (diff)
downloadumapyai-b8e1e39463c61267aef680670fb62c421e4a27b5.tar.xz
umapyai-b8e1e39463c61267aef680670fb62c421e4a27b5.zip
refactor: Move Ollama specific functions to module
-rw-r--r--src/umapyai/__init__.py65
-rw-r--r--src/umapyai/ollama.py69
2 files changed, 70 insertions, 64 deletions
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py
index 021807c..feca194 100644
--- a/src/umapyai/__init__.py
+++ b/src/umapyai/__init__.py
@@ -1,15 +1,13 @@
import os
import sys
-import time
import chromadb
from sentence_transformers import SentenceTransformer
import requests
-import subprocess
-import psutil
from loguru import logger
from .constants import (ARTICLES_DIRECTORY, CHROMA_DIRECTORY, CHROMA_COLLECTION,
CHUNK_SIZE, EMBEDDING_MODEL, OLLAMA_MODEL, TOP_K,
OLLAMA_URL)
+from .ollama import start_ollama_server, is_ollama_live, ensure_model_pulled, kill_ollama
logger.remove()
logger.add(
@@ -20,67 +18,6 @@ logger.add(
)
-def is_ollama_live():
- try:
- response = requests.get(f"{OLLAMA_URL}/api/tags", timeout=2)
-
- return response.status_code == 200
- except Exception:
- return False
-
-
-def start_ollama_server():
- logger.info("Starting Ollama server with OLLAMA_ORIGINS='*' ...")
-
- environment = os.environ.copy()
- environment["OLLAMA_ORIGINS"] = "*"
- process = subprocess.Popen(["ollama", "serve"],
- env=environment,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- text=True)
-
- for _ in range(30):
- if is_ollama_live():
- logger.success("Ollama is now live.")
-
- return process
-
- time.sleep(1)
-
- logger.error("Ollama server did not start after 30 seconds.")
- process.terminate()
- sys.exit(1)
-
-
-def kill_ollama(process):
- logger.info("Killing Ollama ...")
-
- try:
- parent_process = psutil.Process(process.pid)
-
- for child_process in parent_process.children(recursive=True):
- child_process.terminate()
-
- parent_process.terminate()
- except Exception as error:
- logger.error(f"Error killing Ollama: {error}")
-
-
-def ensure_model_pulled(model):
- try:
- tags = requests.get(f"{OLLAMA_URL}/api/tags").json().get("models", [])
-
- if not any(model in m.get("name", "") for m in tags):
- logger.info(f"Pulling model '{model}' ...")
- subprocess.run(["ollama", "pull", model], check=True)
- else:
- logger.success(f"Model '{model}' already pulled.")
- except Exception as e:
- logger.warning(f"Couldn't check/pull Ollama model: {e}")
- logger.warning("Proceeding anyway ...")
-
-
def main():
ollama_process = None
started_ollama = False
diff --git a/src/umapyai/ollama.py b/src/umapyai/ollama.py
new file mode 100644
index 0000000..73329be
--- /dev/null
+++ b/src/umapyai/ollama.py
@@ -0,0 +1,69 @@
+import requests
+import time
+import subprocess
+import psutil
+from .constants import OLLAMA_URL
+import os
+from loguru import logger
+import sys
+
+
+def is_ollama_live():
+ try:
+ response = requests.get(f"{OLLAMA_URL}/api/tags", timeout=2)
+
+ return response.status_code == 200
+ except Exception:
+ return False
+
+
+def start_ollama_server():
+ logger.info("Starting Ollama server with OLLAMA_ORIGINS='*' ...")
+
+ environment = os.environ.copy()
+ environment["OLLAMA_ORIGINS"] = "*"
+ process = subprocess.Popen(["ollama", "serve"],
+ env=environment,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True)
+
+ for _ in range(30):
+ if is_ollama_live():
+ logger.success("Ollama is now live.")
+
+ return process
+
+ time.sleep(1)
+
+ logger.error("Ollama server did not start after 30 seconds.")
+ process.terminate()
+ sys.exit(1)
+
+
+def kill_ollama(process):
+ logger.info("Killing Ollama ...")
+
+ try:
+ parent_process = psutil.Process(process.pid)
+
+ for child_process in parent_process.children(recursive=True):
+ child_process.terminate()
+
+ parent_process.terminate()
+ except Exception as error:
+ logger.error(f"Error killing Ollama: {error}")
+
+
+def ensure_model_pulled(model):
+ try:
+ tags = requests.get(f"{OLLAMA_URL}/api/tags").json().get("models", [])
+
+ if not any(model in m.get("name", "") for m in tags):
+ logger.info(f"Pulling model '{model}' ...")
+ subprocess.run(["ollama", "pull", model], check=True)
+ else:
+ logger.success(f"Model '{model}' already pulled.")
+ except Exception as e:
+ logger.warning(f"Couldn't check/pull Ollama model: {e}")
+ logger.warning("Proceeding anyway ...")