aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFuwn <[email protected]>2025-07-29 13:25:56 +0200
committerFuwn <[email protected]>2025-07-29 13:25:56 +0200
commit4753f8d5198f76398c9bf24e2c8952a6cf0e403f (patch)
treeaba36833b661691edc76cf26357e3aed03d971c8
parentfeat(umapyai): Improve chunk search (diff)
downloadumapyai-4753f8d5198f76398c9bf24e2c8952a6cf0e403f.tar.xz
umapyai-4753f8d5198f76398c9bf24e2c8952a6cf0e403f.zip
feat(umapyai): Improve chunk search
-rw-r--r--pyproject.toml1
-rw-r--r--requirements-dev.lock70
-rw-r--r--requirements.lock70
-rw-r--r--src/umapyai/__init__.py75
4 files changed, 196 insertions, 20 deletions
diff --git a/pyproject.toml b/pyproject.toml
index d0c7537..d378722 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
"flask>=3.1.1",
"flask-cors>=6.0.1",
"flask-sock>=0.7.0",
+ "spacy>=3.8.7",
]
readme = "README.md"
requires-python = ">= 3.8"
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 7f07737..6781bb3 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -26,10 +26,16 @@ beautifulsoup4==4.13.4
# via umapyai
blinker==1.9.0
# via flask
+blis==1.3.0
+ # via thinc
build==1.2.2.post1
# via chromadb
cachetools==5.5.2
# via google-auth
+catalogue==2.0.10
+ # via spacy
+ # via srsly
+ # via thinc
certifi==2025.7.14
# via httpcore
# via httpx
@@ -43,8 +49,17 @@ click==8.2.1
# via flask
# via typer
# via uvicorn
+cloudpathlib==0.21.1
+ # via weasel
coloredlogs==15.0.1
# via onnxruntime
+confection==0.1.5
+ # via thinc
+ # via weasel
+cymem==2.0.11
+ # via preshed
+ # via spacy
+ # via thinc
distro==1.9.0
# via posthog
durationpy==0.10
@@ -103,6 +118,7 @@ itsdangerous==2.2.0
# via flask
jinja2==3.1.6
# via flask
+ # via spacy
# via torch
joblib==1.5.1
# via scikit-learn
@@ -112,8 +128,14 @@ jsonschema-specifications==2025.4.1
# via jsonschema
kubernetes==33.1.0
# via chromadb
+langcodes==3.5.0
+ # via spacy
+language-data==1.3.0
+ # via langcodes
loguru==0.7.3
# via umapyai
+marisa-trie==1.2.1
+ # via language-data
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
@@ -126,13 +148,20 @@ mmh3==5.1.0
# via chromadb
mpmath==1.3.0
# via sympy
+murmurhash==1.0.13
+ # via preshed
+ # via spacy
+ # via thinc
networkx==3.5
# via torch
numpy==2.3.2
+ # via blis
# via chromadb
# via onnxruntime
# via scikit-learn
# via scipy
+ # via spacy
+ # via thinc
# via transformers
oauthlib==3.3.1
# via kubernetes
@@ -164,13 +193,19 @@ packaging==25.0
# via build
# via huggingface-hub
# via onnxruntime
+ # via spacy
+ # via thinc
# via transformers
+ # via weasel
pillow==11.3.0
# via sentence-transformers
platformdirs==4.3.8
# via yapf
posthog==5.4.0
# via chromadb
+preshed==3.0.10
+ # via spacy
+ # via thinc
protobuf==6.31.1
# via googleapis-common-protos
# via onnxruntime
@@ -186,6 +221,10 @@ pybase64==1.4.2
# via chromadb
pydantic==2.11.7
# via chromadb
+ # via confection
+ # via spacy
+ # via thinc
+ # via weasel
pydantic-core==2.33.2
# via pydantic
pygments==2.19.2
@@ -215,8 +254,10 @@ requests==2.32.4
# via kubernetes
# via posthog
# via requests-oauthlib
+ # via spacy
# via transformers
# via umapyai
+ # via weasel
requests-oauthlib==2.0.0
# via kubernetes
rich==14.1.0
@@ -238,6 +279,9 @@ scipy==1.16.1
sentence-transformers==5.0.0
# via umapyai
setuptools==80.9.0
+ # via marisa-trie
+ # via spacy
+ # via thinc
# via torch
shellingham==1.5.4
# via typer
@@ -247,15 +291,30 @@ six==1.17.0
# via kubernetes
# via posthog
# via python-dateutil
+smart-open==7.3.0.post1
+ # via weasel
sniffio==1.3.1
# via anyio
soupsieve==2.7
# via beautifulsoup4
+spacy==3.8.7
+ # via umapyai
+spacy-legacy==3.0.12
+ # via spacy
+spacy-loggers==1.0.5
+ # via spacy
+srsly==2.5.1
+ # via confection
+ # via spacy
+ # via thinc
+ # via weasel
sympy==1.14.0
# via onnxruntime
# via torch
tenacity==9.1.2
# via chromadb
+thinc==8.3.6
+ # via spacy
threadpoolctl==3.6.0
# via scikit-learn
tokenizers==0.21.2
@@ -267,11 +326,14 @@ tqdm==4.67.1
# via chromadb
# via huggingface-hub
# via sentence-transformers
+ # via spacy
# via transformers
transformers==4.54.0
# via sentence-transformers
typer==0.16.0
# via chromadb
+ # via spacy
+ # via weasel
typing-extensions==4.14.1
# via beautifulsoup4
# via chromadb
@@ -295,8 +357,14 @@ uvicorn==0.35.0
# via chromadb
uvloop==0.21.0
# via uvicorn
+wasabi==1.1.3
+ # via spacy
+ # via thinc
+ # via weasel
watchfiles==1.1.0
# via uvicorn
+weasel==0.4.1
+ # via spacy
websocket-client==1.8.0
# via kubernetes
websockets==15.0.1
@@ -304,6 +372,8 @@ websockets==15.0.1
werkzeug==3.1.3
# via flask
# via flask-cors
+wrapt==1.17.2
+ # via smart-open
wsproto==1.2.0
# via simple-websocket
yapf==0.43.0
diff --git a/requirements.lock b/requirements.lock
index ab8398b..bd090a9 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -26,10 +26,16 @@ beautifulsoup4==4.13.4
# via umapyai
blinker==1.9.0
# via flask
+blis==1.3.0
+ # via thinc
build==1.2.2.post1
# via chromadb
cachetools==5.5.2
# via google-auth
+catalogue==2.0.10
+ # via spacy
+ # via srsly
+ # via thinc
certifi==2025.7.14
# via httpcore
# via httpx
@@ -43,8 +49,17 @@ click==8.2.1
# via flask
# via typer
# via uvicorn
+cloudpathlib==0.21.1
+ # via weasel
coloredlogs==15.0.1
# via onnxruntime
+confection==0.1.5
+ # via thinc
+ # via weasel
+cymem==2.0.11
+ # via preshed
+ # via spacy
+ # via thinc
distro==1.9.0
# via posthog
durationpy==0.10
@@ -103,6 +118,7 @@ itsdangerous==2.2.0
# via flask
jinja2==3.1.6
# via flask
+ # via spacy
# via torch
joblib==1.5.1
# via scikit-learn
@@ -112,8 +128,14 @@ jsonschema-specifications==2025.4.1
# via jsonschema
kubernetes==33.1.0
# via chromadb
+langcodes==3.5.0
+ # via spacy
+language-data==1.3.0
+ # via langcodes
loguru==0.7.3
# via umapyai
+marisa-trie==1.2.1
+ # via language-data
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
@@ -126,13 +148,20 @@ mmh3==5.1.0
# via chromadb
mpmath==1.3.0
# via sympy
+murmurhash==1.0.13
+ # via preshed
+ # via spacy
+ # via thinc
networkx==3.5
# via torch
numpy==2.3.2
+ # via blis
# via chromadb
# via onnxruntime
# via scikit-learn
# via scipy
+ # via spacy
+ # via thinc
# via transformers
oauthlib==3.3.1
# via kubernetes
@@ -164,11 +193,17 @@ packaging==25.0
# via build
# via huggingface-hub
# via onnxruntime
+ # via spacy
+ # via thinc
# via transformers
+ # via weasel
pillow==11.3.0
# via sentence-transformers
posthog==5.4.0
# via chromadb
+preshed==3.0.10
+ # via spacy
+ # via thinc
protobuf==6.31.1
# via googleapis-common-protos
# via onnxruntime
@@ -184,6 +219,10 @@ pybase64==1.4.2
# via chromadb
pydantic==2.11.7
# via chromadb
+ # via confection
+ # via spacy
+ # via thinc
+ # via weasel
pydantic-core==2.33.2
# via pydantic
pygments==2.19.2
@@ -213,8 +252,10 @@ requests==2.32.4
# via kubernetes
# via posthog
# via requests-oauthlib
+ # via spacy
# via transformers
# via umapyai
+ # via weasel
requests-oauthlib==2.0.0
# via kubernetes
rich==14.1.0
@@ -235,6 +276,9 @@ scipy==1.16.1
sentence-transformers==5.0.0
# via umapyai
setuptools==80.9.0
+ # via marisa-trie
+ # via spacy
+ # via thinc
# via torch
shellingham==1.5.4
# via typer
@@ -244,15 +288,30 @@ six==1.17.0
# via kubernetes
# via posthog
# via python-dateutil
+smart-open==7.3.0.post1
+ # via weasel
sniffio==1.3.1
# via anyio
soupsieve==2.7
# via beautifulsoup4
+spacy==3.8.7
+ # via umapyai
+spacy-legacy==3.0.12
+ # via spacy
+spacy-loggers==1.0.5
+ # via spacy
+srsly==2.5.1
+ # via confection
+ # via spacy
+ # via thinc
+ # via weasel
sympy==1.14.0
# via onnxruntime
# via torch
tenacity==9.1.2
# via chromadb
+thinc==8.3.6
+ # via spacy
threadpoolctl==3.6.0
# via scikit-learn
tokenizers==0.21.2
@@ -264,11 +323,14 @@ tqdm==4.67.1
# via chromadb
# via huggingface-hub
# via sentence-transformers
+ # via spacy
# via transformers
transformers==4.54.0
# via sentence-transformers
typer==0.16.0
# via chromadb
+ # via spacy
+ # via weasel
typing-extensions==4.14.1
# via beautifulsoup4
# via chromadb
@@ -292,8 +354,14 @@ uvicorn==0.35.0
# via chromadb
uvloop==0.21.0
# via uvicorn
+wasabi==1.1.3
+ # via spacy
+ # via thinc
+ # via weasel
watchfiles==1.1.0
# via uvicorn
+weasel==0.4.1
+ # via spacy
websocket-client==1.8.0
# via kubernetes
websockets==15.0.1
@@ -301,6 +369,8 @@ websockets==15.0.1
werkzeug==3.1.3
# via flask
# via flask-cors
+wrapt==1.17.2
+ # via smart-open
wsproto==1.2.0
# via simple-websocket
zipp==3.23.0
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py
index 3111c00..fd1a2c7 100644
--- a/src/umapyai/__init__.py
+++ b/src/umapyai/__init__.py
@@ -13,6 +13,9 @@ from .constants import (ARTICLES_DIRECTORY, CHROMA_DIRECTORY, CHROMA_COLLECTION,
CHUNK_SIZE, EMBEDDING_MODEL, OLLAMA_MODEL, TOP_K,
OLLAMA_URL)
from .ollama import start_ollama_server, is_ollama_live, ensure_model_pulled, kill_ollama
+import spacy
+import re
+from collections import defaultdict
logger.remove()
logger.add(
@@ -23,11 +26,11 @@ logger.add(
)
app = Flask(__name__)
+socket = Sock(app)
+language = spacy.load("en_core_web_sm")
CORS(app)
-socket = Sock(app)
-
def normalize(text):
return text.replace('_', ' ').replace('-', ' ').replace('.txt',
@@ -49,6 +52,30 @@ def get_significant_filename_parts(filename):
return set(n_grams)
+
+def get_query_phrases(query):
+ document = language(normalize(query))
+ words = [
+ token.text for token in document if not token.is_stop and token.is_alpha
+ ]
+ phrases = set()
+
+ for chunk in document.noun_chunks:
+ phrases.add(chunk.text.lower().strip())
+
+ for ent in document.ents:
+ phrases.add(ent.text.lower().strip())
+
+ for n in [2, 3]:
+ for i in range(len(words) - n + 1):
+ phrases.add(' '.join(words[i:i + n]))
+
+ for word in words:
+ phrases.add(word)
+
+ return {phrase for phrase in phrases if len(phrase) > 2}
+
+
def prompt(rag_context, user_query, is_first_turn=True):
if is_first_turn:
system_prompt = (
@@ -170,33 +197,41 @@ def main():
"source": chunk["source"]
}])
+ def clean_for_match(text):
+ return re.sub(r'\W+', ' ', text.lower()).strip()
+
def find_relevant_chunks(query, top_k=TOP_K):
- normalised_query = normalize(query)
- forced_matches = []
+ query_phrases = get_query_phrases(query)
+ document_match_count = defaultdict(int)
+ document_best_chunk = {}
for chunk in chunks:
- parts = get_significant_filename_parts(chunk["source"])
-
- if any(part in normalised_query for part in parts if len(part) > 2):
- forced_matches.append((chunk["chunk"], {"source": chunk["source"]}))
-
- unique_sources = set()
- deduped_forced = []
-
- for document, metadata in forced_matches:
- if metadata['source'] not in unique_sources:
- deduped_forced.append((document, metadata))
- unique_sources.add(metadata['source'])
-
+ chunk_text = clean_for_match(chunk["chunk"])
+ source = chunk["source"]
+
+ for query_phrase in query_phrases:
+ if clean_for_match(query_phrase) in chunk_text:
+ document_match_count[source] += 1
+
+ if source not in document_best_chunk:
+ document_best_chunk[source] = (chunk["chunk"], {"source": source})
+
+ sorted_sources = sorted(
+ document_match_count, key=lambda k: -document_match_count[k])
+ deduped_forced = [
+ document_best_chunk[source]
+ for source in sorted_sources
+ if document_match_count[source] > 0
+ ]
q_embed = model.encode(query)
results = collection.query(
query_embeddings=[q_embed.tolist()], n_results=top_k * 2)
- semantic_docs = [(document, metadata) for document, metadata in zip(
+ semantic_documents = [(document, metadata) for document, metadata in zip(
results['documents'][0], results['metadatas'][0])]
+ seen = set(metadata['source'] for _, metadata in deduped_forced)
merged = deduped_forced.copy()
- seen = set(meta['source'] for _, meta in deduped_forced)
- for document, metadata in semantic_docs:
+ for document, metadata in semantic_documents:
if metadata['source'] not in seen and len(merged) < top_k:
merged.append((document, metadata))
seen.add(metadata['source'])