diff options
| author | Fuwn <[email protected]> | 2025-07-29 13:25:56 +0200 |
|---|---|---|
| committer | Fuwn <[email protected]> | 2025-07-29 13:25:56 +0200 |
| commit | 4753f8d5198f76398c9bf24e2c8952a6cf0e403f (patch) | |
| tree | aba36833b661691edc76cf26357e3aed03d971c8 /src | |
| parent | feat(umapyai): Improve chunk search (diff) | |
| download | umapyai-4753f8d5198f76398c9bf24e2c8952a6cf0e403f.tar.xz umapyai-4753f8d5198f76398c9bf24e2c8952a6cf0e403f.zip | |
feat(umapyai): Improve chunk search
Diffstat (limited to 'src')
| -rw-r--r-- | src/umapyai/__init__.py | 75 |
1 files changed, 55 insertions, 20 deletions
diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py index 3111c00..fd1a2c7 100644 --- a/src/umapyai/__init__.py +++ b/src/umapyai/__init__.py @@ -13,6 +13,9 @@ from .constants import (ARTICLES_DIRECTORY, CHROMA_DIRECTORY, CHROMA_COLLECTION, CHUNK_SIZE, EMBEDDING_MODEL, OLLAMA_MODEL, TOP_K, OLLAMA_URL) from .ollama import start_ollama_server, is_ollama_live, ensure_model_pulled, kill_ollama +import spacy +import re +from collections import defaultdict logger.remove() logger.add( @@ -23,11 +26,11 @@ logger.add( ) app = Flask(__name__) +socket = Sock(app) +language = spacy.load("en_core_web_sm") CORS(app) -socket = Sock(app) - def normalize(text): return text.replace('_', ' ').replace('-', ' ').replace('.txt', @@ -49,6 +52,30 @@ def get_significant_filename_parts(filename): return set(n_grams) + +def get_query_phrases(query): + document = language(normalize(query)) + words = [ + token.text for token in document if not token.is_stop and token.is_alpha + ] + phrases = set() + + for chunk in document.noun_chunks: + phrases.add(chunk.text.lower().strip()) + + for ent in document.ents: + phrases.add(ent.text.lower().strip()) + + for n in [2, 3]: + for i in range(len(words) - n + 1): + phrases.add(' '.join(words[i:i + n])) + + for word in words: + phrases.add(word) + + return {phrase for phrase in phrases if len(phrase) > 2} + + def prompt(rag_context, user_query, is_first_turn=True): if is_first_turn: system_prompt = ( @@ -170,33 +197,41 @@ def main(): "source": chunk["source"] }]) + def clean_for_match(text): + return re.sub(r'\W+', ' ', text.lower()).strip() + def find_relevant_chunks(query, top_k=TOP_K): - normalised_query = normalize(query) - forced_matches = [] + query_phrases = get_query_phrases(query) + document_match_count = defaultdict(int) + document_best_chunk = {} for chunk in chunks: - parts = get_significant_filename_parts(chunk["source"]) - - if any(part in normalised_query for part in parts if len(part) > 2): - forced_matches.append((chunk["chunk"], {"source": chunk["source"]})) - - unique_sources = set() - deduped_forced = [] - - for document, metadata in forced_matches: - if metadata['source'] not in unique_sources: - deduped_forced.append((document, metadata)) - unique_sources.add(metadata['source']) - + chunk_text = clean_for_match(chunk["chunk"]) + source = chunk["source"] + + for query_phrase in query_phrases: + if clean_for_match(query_phrase) in chunk_text: + document_match_count[source] += 1 + + if source not in document_best_chunk: + document_best_chunk[source] = (chunk["chunk"], {"source": source}) + + sorted_sources = sorted( + document_match_count, key=lambda k: -document_match_count[k]) + deduped_forced = [ + document_best_chunk[source] + for source in sorted_sources + if document_match_count[source] > 0 + ] q_embed = model.encode(query) results = collection.query( query_embeddings=[q_embed.tolist()], n_results=top_k * 2) - semantic_docs = [(document, metadata) for document, metadata in zip( + semantic_documents = [(document, metadata) for document, metadata in zip( results['documents'][0], results['metadatas'][0])] + seen = set(metadata['source'] for _, metadata in deduped_forced) merged = deduped_forced.copy() - seen = set(meta['source'] for _, meta in deduped_forced) - for document, metadata in semantic_documents: + for document, metadata in semantic_documents: if metadata['source'] not in seen and len(merged) < top_k: merged.append((document, metadata)) seen.add(metadata['source']) |