about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorFuwn <[email protected]>2025-07-29 13:25:56 +0200
committerFuwn <[email protected]>2025-07-29 13:25:56 +0200
commit4753f8d5198f76398c9bf24e2c8952a6cf0e403f (patch)
treeaba36833b661691edc76cf26357e3aed03d971c8 /src
parentfeat(umapyai): Improve chunk search (diff)
downloadumapyai-4753f8d5198f76398c9bf24e2c8952a6cf0e403f.tar.xz
umapyai-4753f8d5198f76398c9bf24e2c8952a6cf0e403f.zip
feat(umapyai): Improve chunk search
Diffstat (limited to 'src')
-rw-r--r--src/umapyai/__init__.py75
1 file changed, 55 insertions, 20 deletions
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py
index 3111c00..fd1a2c7 100644
--- a/src/umapyai/__init__.py
+++ b/src/umapyai/__init__.py
@@ -13,6 +13,9 @@ from .constants import (ARTICLES_DIRECTORY, CHROMA_DIRECTORY, CHROMA_COLLECTION,
CHUNK_SIZE, EMBEDDING_MODEL, OLLAMA_MODEL, TOP_K,
OLLAMA_URL)
from .ollama import start_ollama_server, is_ollama_live, ensure_model_pulled, kill_ollama
+import spacy
+import re
+from collections import defaultdict
logger.remove()
logger.add(
@@ -23,11 +26,11 @@ logger.add(
)
app = Flask(__name__)
+socket = Sock(app)
+language = spacy.load("en_core_web_sm")
CORS(app)
-socket = Sock(app)
-
def normalize(text):
return text.replace('_', ' ').replace('-', ' ').replace('.txt',
@@ -49,6 +52,30 @@ def get_significant_filename_parts(filename):
return set(n_grams)
+
+def get_query_phrases(query):
+ document = language(normalize(query))
+ words = [
+ token.text for token in document if not token.is_stop and token.is_alpha
+ ]
+ phrases = set()
+
+ for chunk in document.noun_chunks:
+ phrases.add(chunk.text.lower().strip())
+
+ for ent in document.ents:
+ phrases.add(ent.text.lower().strip())
+
+ for n in [2, 3]:
+ for i in range(len(words) - n + 1):
+ phrases.add(' '.join(words[i:i + n]))
+
+ for word in words:
+ phrases.add(word)
+
+ return {phrase for phrase in phrases if len(phrase) > 2}
+
+
def prompt(rag_context, user_query, is_first_turn=True):
if is_first_turn:
system_prompt = (
@@ -170,33 +197,41 @@ def main():
"source": chunk["source"]
}])
+ def clean_for_match(text):
+ return re.sub(r'\W+', ' ', text.lower()).strip()
+
def find_relevant_chunks(query, top_k=TOP_K):
- normalised_query = normalize(query)
- forced_matches = []
+ query_phrases = get_query_phrases(query)
+ document_match_count = defaultdict(int)
+ document_best_chunk = {}
for chunk in chunks:
- parts = get_significant_filename_parts(chunk["source"])
-
- if any(part in normalised_query for part in parts if len(part) > 2):
- forced_matches.append((chunk["chunk"], {"source": chunk["source"]}))
-
- unique_sources = set()
- deduped_forced = []
-
- for document, metadata in forced_matches:
- if metadata['source'] not in unique_sources:
- deduped_forced.append((document, metadata))
- unique_sources.add(metadata['source'])
-
+ chunk_text = clean_for_match(chunk["chunk"])
+ source = chunk["source"]
+
+ for query_phrase in query_phrases:
+ if clean_for_match(query_phrase) in chunk_text:
+ document_match_count[source] += 1
+
+ if source not in document_best_chunk:
+ document_best_chunk[source] = (chunk["chunk"], {"source": source})
+
+ sorted_sources = sorted(
+ document_match_count, key=lambda k: -document_match_count[k])
+ deduped_forced = [
+ document_best_chunk[source]
+ for source in sorted_sources
+ if document_match_count[source] > 0
+ ]
q_embed = model.encode(query)
results = collection.query(
query_embeddings=[q_embed.tolist()], n_results=top_k * 2)
- semantic_docs = [(document, metadata) for document, metadata in zip(
+ semantic_documents = [(document, metadata) for document, metadata in zip(
results['documents'][0], results['metadatas'][0])]
+ seen = set(metadata['source'] for _, metadata in deduped_forced)
merged = deduped_forced.copy()
- seen = set(meta['source'] for _, meta in deduped_forced)
- for document, metadata in semantic_docs:
+ for document, metadata in semantic_documents:
if metadata['source'] not in seen and len(merged) < top_k:
merged.append((document, metadata))
seen.add(metadata['source'])