diff options
| author | Fuwn <[email protected]> | 2025-07-29 12:41:03 +0200 |
|---|---|---|
| committer | Fuwn <[email protected]> | 2025-07-29 12:41:03 +0200 |
| commit | 1200520fa84373a565fe31a557715d25dfc216e5 (patch) | |
| tree | 15d2f1ee8146ef0b0c0354d4204e2ed326d4c3a5 | |
| parent | feat(umapyai): Response streaming (diff) | |
| download | umapyai-1200520fa84373a565fe31a557715d25dfc216e5.tar.xz umapyai-1200520fa84373a565fe31a557715d25dfc216e5.zip | |
feat(umapyai): Improve chunk search
| -rw-r--r-- | src/umapyai/__init__.py | 55 |
1 file changed, 49 insertions, 6 deletions
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py index 348fd23..3111c00 100644 --- a/src/umapyai/__init__.py +++ b/src/umapyai/__init__.py @@ -29,6 +29,26 @@ CORS(app) socket = Sock(app) +def normalize(text): + return text.replace('_', ' ').replace('-', ' ').replace('.txt', + '').lower().strip() + + +def get_significant_filename_parts(filename): + normalised = normalize(filename) + words = normalised.split() + n_grams = [' '.join(words[i:i + 2]) for i in range(len(words) - 1)] + + if len(words) >= 2: + n_grams.append(' '.join(words[:2])) + + if len(words) >= 3: + n_grams.append(' '.join(words[:3])) + + n_grams.append(words[0]) + + return set(n_grams) + def prompt(rag_context, user_query, is_first_turn=True): if is_first_turn: system_prompt = ( @@ -151,14 +171,37 @@ def main(): }]) def find_relevant_chunks(query, top_k=TOP_K): + normalised_query = normalize(query) + forced_matches = [] + + for chunk in chunks: + parts = get_significant_filename_parts(chunk["source"]) + + if any(part in normalised_query for part in parts if len(part) > 2): + forced_matches.append((chunk["chunk"], {"source": chunk["source"]})) + + unique_sources = set() + deduped_forced = [] + + for document, metadata in forced_matches: + if metadata['source'] not in unique_sources: + deduped_forced.append((document, metadata)) + unique_sources.add(metadata['source']) + q_embed = model.encode(query) results = collection.query( - query_embeddings=[q_embed.tolist()], n_results=top_k) - documents = results['documents'][0] - metadatas = results['metadatas'][0] - - return [(document, metadata) - for document, metadata in zip(documents, metadatas)] + query_embeddings=[q_embed.tolist()], n_results=top_k * 2) + semantic_docs = [(document, metadata) for document, metadata in zip( + results['documents'][0], results['metadatas'][0])] + merged = deduped_forced.copy() + seen = set(meta['source'] for _, meta in deduped_forced) + + for document, metadata in semantic_docs: + if metadata['source'] not in seen and len(merged) < top_k: + merged.append((document, metadata)) + seen.add(metadata['source']) + + return merged[:top_k] def query_ollama(prompt, context=None): url = f"{OLLAMA_URL}/api/generate" |