about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Fuwn <[email protected]> 2025-07-29 12:41:03 +0200
committer: Fuwn <[email protected]> 2025-07-29 12:41:03 +0200
commit: 1200520fa84373a565fe31a557715d25dfc216e5 (patch)
tree: 15d2f1ee8146ef0b0c0354d4204e2ed326d4c3a5
parent: feat(umapyai): Response streaming (diff)
download: umapyai-1200520fa84373a565fe31a557715d25dfc216e5.tar.xz
download: umapyai-1200520fa84373a565fe31a557715d25dfc216e5.zip
feat(umapyai): Improve chunk search
-rw-r--r-- src/umapyai/__init__.py 55
1 files changed, 49 insertions, 6 deletions
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py
index 348fd23..3111c00 100644
--- a/src/umapyai/__init__.py
+++ b/src/umapyai/__init__.py
@@ -29,6 +29,26 @@ CORS(app)
socket = Sock(app)
+def normalize(text):
+  """Return `text` in the canonical form used for filename/query matching.
+
+  The same normalisation is applied to source filenames and to user queries
+  so that a substring comparison between the two is meaningful.
+
+  Steps:
+    * lowercase first, so the extension is removed whatever its case
+      (`.txt`, `.TXT`, `.Txt`);
+    * remove every occurrence of `.txt`;
+    * turn `_` and `-` into spaces;
+    * collapse runs of whitespace to a single space and trim the ends, so
+      `a__b` and `a_-b` both become `a b`.
+
+  Args:
+    text: A filename or free-form query string.
+
+  Returns:
+    The normalised string; empty if `text` contains nothing but separators,
+    whitespace and the `.txt` extension.
+  """
+  lowered = text.lower().replace('.txt', '')
+  spaced = lowered.replace('_', ' ').replace('-', ' ')
+
+  # split()/join collapses repeated whitespace and strips both ends at once.
+  return ' '.join(spaced.split())
+
+
+def get_significant_filename_parts(filename):
+  """Return the set of word n-grams used to spot `filename` in a query.
+
+  The filename is normalised with `normalize` and split on whitespace. The
+  result contains:
+    * every pair of adjacent words,
+    * the leading three-word phrase, when there are at least three words,
+    * the first word on its own.
+
+  Args:
+    filename: Source filename of a chunk.
+
+  Returns:
+    A set of lowercase phrases. Empty when the filename normalises to no
+    words at all (for example `".txt"` or `"___"`), rather than raising
+    `IndexError` on the missing first word.
+  """
+  words = normalize(filename).split()
+
+  if not words:
+    return set()
+
+  # Adjacent pairs; this already covers the leading pair words[:2].
+  parts = {f'{left} {right}' for left, right in zip(words, words[1:])}
+
+  if len(words) >= 3:
+    parts.add(' '.join(words[:3]))
+
+  parts.add(words[0])
+
+  return parts
+
def prompt(rag_context, user_query, is_first_turn=True):
if is_first_turn:
system_prompt = (
@@ -151,14 +171,37 @@ def main():
}])
def find_relevant_chunks(query, top_k=TOP_K):
+ """Return at most `top_k` (document, metadata) pairs for `query`.
+
+ Chunks whose source filename is mentioned in the query are placed first
+ (one per source); remaining slots are filled from the embedding search,
+ again at most one result per source.
+
+ `chunks`, `model` and `collection` are taken from the enclosing scope,
+ which is outside this hunk.
+ """
+ # Normalise the query the same way filenames are normalised so that the
+ # substring test below compares like with like.
+ normalised_query = normalize(query)
+ forced_matches = []
+
+ for chunk in chunks:
+ parts = get_significant_filename_parts(chunk["source"])
+
+ # A chunk is force-included when any filename phrase longer than two
+ # characters occurs as a substring of the normalised query.
+ if any(part in normalised_query for part in parts if len(part) > 2):
+ forced_matches.append((chunk["chunk"], {"source": chunk["source"]}))
+
+ # Keep only the first forced chunk encountered for each source.
+ unique_sources = set()
+ deduped_forced = []
+
+ for document, metadata in forced_matches:
+ if metadata['source'] not in unique_sources:
+ deduped_forced.append((document, metadata))
+ unique_sources.add(metadata['source'])
+
+ # Ask for twice `top_k` results so that enough candidates remain after
+ # results from already-seen sources are dropped below.
q_embed = model.encode(query)
results = collection.query(
- query_embeddings=[q_embed.tolist()], n_results=top_k)
- documents = results['documents'][0]
- metadatas = results['metadatas'][0]
-
- return [(document, metadata)
- for document, metadata in zip(documents, metadatas)]
+ query_embeddings=[q_embed.tolist()], n_results=top_k * 2)
+ semantic_docs = [(document, metadata) for document, metadata in zip(
+ results['documents'][0], results['metadatas'][0])]
+ # Forced matches go first, in the order their sources appear in `chunks`.
+ merged = deduped_forced.copy()
+ # NOTE(review): `seen` holds the same sources as `unique_sources` above.
+ seen = set(meta['source'] for _, meta in deduped_forced)
+
+ # Append search results from sources not yet represented until `top_k`
+ # entries have been collected.
+ for document, metadata in semantic_docs:
+ if metadata['source'] not in seen and len(merged) < top_k:
+ merged.append((document, metadata))
+ seen.add(metadata['source'])
+
+ # The slice matters: nothing above limits the number of forced matches,
+ # so `merged` can start out longer than `top_k`.
+ return merged[:top_k]
def query_ollama(prompt, context=None):
url = f"{OLLAMA_URL}/api/generate"