diff options
| author | Fuwn <[email protected]> | 2025-07-29 13:25:56 +0200 |
|---|---|---|
| committer | Fuwn <[email protected]> | 2025-07-29 13:25:56 +0200 |
| commit | 4753f8d5198f76398c9bf24e2c8952a6cf0e403f (patch) | |
| tree | aba36833b661691edc76cf26357e3aed03d971c8 | |
| parent | feat(umapyai): Improve chunk search (diff) | |
| download | umapyai-4753f8d5198f76398c9bf24e2c8952a6cf0e403f.tar.xz umapyai-4753f8d5198f76398c9bf24e2c8952a6cf0e403f.zip | |
feat(umapyai): Improve chunk search
| -rw-r--r-- | pyproject.toml | 1 | ||||
| -rw-r--r-- | requirements-dev.lock | 70 | ||||
| -rw-r--r-- | requirements.lock | 70 | ||||
| -rw-r--r-- | src/umapyai/__init__.py | 75 |
4 files changed, 196 insertions, 20 deletions
diff --git a/pyproject.toml b/pyproject.toml index d0c7537..d378722 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "flask>=3.1.1", "flask-cors>=6.0.1", "flask-sock>=0.7.0", + "spacy>=3.8.7", ] readme = "README.md" requires-python = ">= 3.8" diff --git a/requirements-dev.lock b/requirements-dev.lock index 7f07737..6781bb3 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -26,10 +26,16 @@ beautifulsoup4==4.13.4 # via umapyai blinker==1.9.0 # via flask +blis==1.3.0 + # via thinc build==1.2.2.post1 # via chromadb cachetools==5.5.2 # via google-auth +catalogue==2.0.10 + # via spacy + # via srsly + # via thinc certifi==2025.7.14 # via httpcore # via httpx @@ -43,8 +49,17 @@ click==8.2.1 # via flask # via typer # via uvicorn +cloudpathlib==0.21.1 + # via weasel coloredlogs==15.0.1 # via onnxruntime +confection==0.1.5 + # via thinc + # via weasel +cymem==2.0.11 + # via preshed + # via spacy + # via thinc distro==1.9.0 # via posthog durationpy==0.10 @@ -103,6 +118,7 @@ itsdangerous==2.2.0 # via flask jinja2==3.1.6 # via flask + # via spacy # via torch joblib==1.5.1 # via scikit-learn @@ -112,8 +128,14 @@ jsonschema-specifications==2025.4.1 # via jsonschema kubernetes==33.1.0 # via chromadb +langcodes==3.5.0 + # via spacy +language-data==1.3.0 + # via langcodes loguru==0.7.3 # via umapyai +marisa-trie==1.2.1 + # via language-data markdown-it-py==3.0.0 # via rich markupsafe==3.0.2 @@ -126,13 +148,20 @@ mmh3==5.1.0 # via chromadb mpmath==1.3.0 # via sympy +murmurhash==1.0.13 + # via preshed + # via spacy + # via thinc networkx==3.5 # via torch numpy==2.3.2 + # via blis # via chromadb # via onnxruntime # via scikit-learn # via scipy + # via spacy + # via thinc # via transformers oauthlib==3.3.1 # via kubernetes @@ -164,13 +193,19 @@ packaging==25.0 # via build # via huggingface-hub # via onnxruntime + # via spacy + # via thinc # via transformers + # via weasel pillow==11.3.0 # via sentence-transformers platformdirs==4.3.8 # via yapf posthog==5.4.0 # via chromadb +preshed==3.0.10 + # via spacy + # via thinc protobuf==6.31.1 # via googleapis-common-protos # via onnxruntime @@ -186,6 +221,10 @@ pybase64==1.4.2 # via chromadb pydantic==2.11.7 # via chromadb + # via confection + # via spacy + # via thinc + # via weasel pydantic-core==2.33.2 # via pydantic pygments==2.19.2 @@ -215,8 +254,10 @@ requests==2.32.4 # via kubernetes # via posthog # via requests-oauthlib + # via spacy # via transformers # via umapyai + # via weasel requests-oauthlib==2.0.0 # via kubernetes rich==14.1.0 @@ -238,6 +279,9 @@ scipy==1.16.1 sentence-transformers==5.0.0 # via umapyai setuptools==80.9.0 + # via marisa-trie + # via spacy + # via thinc # via torch shellingham==1.5.4 # via typer @@ -247,15 +291,30 @@ six==1.17.0 # via kubernetes # via posthog # via python-dateutil +smart-open==7.3.0.post1 + # via weasel sniffio==1.3.1 # via anyio soupsieve==2.7 # via beautifulsoup4 +spacy==3.8.7 + # via umapyai +spacy-legacy==3.0.12 + # via spacy +spacy-loggers==1.0.5 + # via spacy +srsly==2.5.1 + # via confection + # via spacy + # via thinc + # via weasel sympy==1.14.0 # via onnxruntime # via torch tenacity==9.1.2 # via chromadb +thinc==8.3.6 + # via spacy threadpoolctl==3.6.0 # via scikit-learn tokenizers==0.21.2 @@ -267,11 +326,14 @@ tqdm==4.67.1 # via chromadb # via huggingface-hub # via sentence-transformers + # via spacy # via transformers transformers==4.54.0 # via sentence-transformers typer==0.16.0 # via chromadb + # via spacy + # via weasel typing-extensions==4.14.1 # via beautifulsoup4 # via chromadb @@ -295,8 +357,14 @@ uvicorn==0.35.0 # via chromadb uvloop==0.21.0 # via uvicorn +wasabi==1.1.3 + # via spacy + # via thinc + # via weasel watchfiles==1.1.0 # via uvicorn +weasel==0.4.1 + # via spacy websocket-client==1.8.0 # via kubernetes websockets==15.0.1 @@ -304,6 +372,8 @@ websockets==15.0.1 werkzeug==3.1.3 # via flask # via flask-cors +wrapt==1.17.2 + # via smart-open wsproto==1.2.0 # via simple-websocket yapf==0.43.0 diff --git a/requirements.lock b/requirements.lock index ab8398b..bd090a9 100644 --- a/requirements.lock +++ b/requirements.lock @@ -26,10 +26,16 @@ beautifulsoup4==4.13.4 # via umapyai blinker==1.9.0 # via flask +blis==1.3.0 + # via thinc build==1.2.2.post1 # via chromadb cachetools==5.5.2 # via google-auth +catalogue==2.0.10 + # via spacy + # via srsly + # via thinc certifi==2025.7.14 # via httpcore # via httpx @@ -43,8 +49,17 @@ click==8.2.1 # via flask # via typer # via uvicorn +cloudpathlib==0.21.1 + # via weasel coloredlogs==15.0.1 # via onnxruntime +confection==0.1.5 + # via thinc + # via weasel +cymem==2.0.11 + # via preshed + # via spacy + # via thinc distro==1.9.0 # via posthog durationpy==0.10 @@ -103,6 +118,7 @@ itsdangerous==2.2.0 # via flask jinja2==3.1.6 # via flask + # via spacy # via torch joblib==1.5.1 # via scikit-learn @@ -112,8 +128,14 @@ jsonschema-specifications==2025.4.1 # via jsonschema kubernetes==33.1.0 # via chromadb +langcodes==3.5.0 + # via spacy +language-data==1.3.0 + # via langcodes loguru==0.7.3 # via umapyai +marisa-trie==1.2.1 + # via language-data markdown-it-py==3.0.0 # via rich markupsafe==3.0.2 @@ -126,13 +148,20 @@ mmh3==5.1.0 # via chromadb mpmath==1.3.0 # via sympy +murmurhash==1.0.13 + # via preshed + # via spacy + # via thinc networkx==3.5 # via torch numpy==2.3.2 + # via blis # via chromadb # via onnxruntime # via scikit-learn # via scipy + # via spacy + # via thinc # via transformers oauthlib==3.3.1 # via kubernetes @@ -164,11 +193,17 @@ packaging==25.0 # via build # via huggingface-hub # via onnxruntime + # via spacy + # via thinc # via transformers + # via weasel pillow==11.3.0 # via sentence-transformers posthog==5.4.0 # via chromadb +preshed==3.0.10 + # via spacy + # via thinc protobuf==6.31.1 # via googleapis-common-protos # via onnxruntime @@ -184,6 +219,10 @@ pybase64==1.4.2 # via chromadb pydantic==2.11.7 # via chromadb + # via confection + # via spacy + # via thinc + # via weasel pydantic-core==2.33.2 # via pydantic pygments==2.19.2 @@ -213,8 +252,10 @@ requests==2.32.4 # via kubernetes # via posthog # via requests-oauthlib + # via spacy # via transformers # via umapyai + # via weasel requests-oauthlib==2.0.0 # via kubernetes rich==14.1.0 @@ -235,6 +276,9 @@ scipy==1.16.1 sentence-transformers==5.0.0 # via umapyai setuptools==80.9.0 + # via marisa-trie + # via spacy + # via thinc # via torch shellingham==1.5.4 # via typer @@ -244,15 +288,30 @@ six==1.17.0 # via kubernetes # via posthog # via python-dateutil +smart-open==7.3.0.post1 + # via weasel sniffio==1.3.1 # via anyio soupsieve==2.7 # via beautifulsoup4 +spacy==3.8.7 + # via umapyai +spacy-legacy==3.0.12 + # via spacy +spacy-loggers==1.0.5 + # via spacy +srsly==2.5.1 + # via confection + # via spacy + # via thinc + # via weasel sympy==1.14.0 # via onnxruntime # via torch tenacity==9.1.2 # via chromadb +thinc==8.3.6 + # via spacy threadpoolctl==3.6.0 # via scikit-learn tokenizers==0.21.2 @@ -264,11 +323,14 @@ tqdm==4.67.1 # via chromadb # via huggingface-hub # via sentence-transformers + # via spacy # via transformers transformers==4.54.0 # via sentence-transformers typer==0.16.0 # via chromadb + # via spacy + # via weasel typing-extensions==4.14.1 # via beautifulsoup4 # via chromadb @@ -292,8 +354,14 @@ uvicorn==0.35.0 # via chromadb uvloop==0.21.0 # via uvicorn +wasabi==1.1.3 + # via spacy + # via thinc + # via weasel watchfiles==1.1.0 # via uvicorn +weasel==0.4.1 + # via spacy websocket-client==1.8.0 # via kubernetes websockets==15.0.1 @@ -301,6 +369,8 @@ websockets==15.0.1 werkzeug==3.1.3 # via flask # via flask-cors +wrapt==1.17.2 + # via smart-open wsproto==1.2.0 # via simple-websocket zipp==3.23.0 diff --git a/src/umapyai/__init__.py b/src/umapyai/__init__.py index 3111c00..fd1a2c7 100644 --- a/src/umapyai/__init__.py +++ b/src/umapyai/__init__.py @@ -13,6 +13,9 @@ from .constants import (ARTICLES_DIRECTORY, CHROMA_DIRECTORY, CHROMA_COLLECTION, CHUNK_SIZE, EMBEDDING_MODEL, OLLAMA_MODEL, TOP_K, OLLAMA_URL) from .ollama import start_ollama_server, is_ollama_live, ensure_model_pulled, kill_ollama +import spacy +import re +from collections import defaultdict logger.remove() logger.add( @@ -23,11 +26,11 @@ logger.add( ) app = Flask(__name__) +socket = Sock(app) +language = spacy.load("en_core_web_sm") CORS(app) -socket = Sock(app) - def normalize(text): return text.replace('_', ' ').replace('-', ' ').replace('.txt', @@ -49,6 +52,30 @@ def get_significant_filename_parts(filename): return set(n_grams) + +def get_query_phrases(query): + document = language(normalize(query)) + words = [ + token.text for token in document if not token.is_stop and token.is_alpha + ] + phrases = set() + + for chunk in document.noun_chunks: + phrases.add(chunk.text.lower().strip()) + + for ent in document.ents: + phrases.add(ent.text.lower().strip()) + + for n in [2, 3]: + for i in range(len(words) - n + 1): + phrases.add(' '.join(words[i:i + n])) + + for word in words: + phrases.add(word) + + return {phrase for phrase in phrases if len(phrase) > 2} + + def prompt(rag_context, user_query, is_first_turn=True): if is_first_turn: system_prompt = ( @@ -170,33 +197,41 @@ def main(): "source": chunk["source"] }]) + def clean_for_match(text): + return re.sub(r'\W+', ' ', text.lower()).strip() + def find_relevant_chunks(query, top_k=TOP_K): - normalised_query = normalize(query) - forced_matches = [] + query_phrases = get_query_phrases(query) + document_match_count = defaultdict(int) + document_best_chunk = {} for chunk in chunks: - parts = get_significant_filename_parts(chunk["source"]) - - if any(part in normalised_query for part in parts if len(part) > 2): - forced_matches.append((chunk["chunk"], {"source": chunk["source"]})) - - unique_sources = set() - deduped_forced = [] - - for document, metadata in forced_matches: - if metadata['source'] not in unique_sources: - deduped_forced.append((document, metadata)) - unique_sources.add(metadata['source']) - + chunk_text = clean_for_match(chunk["chunk"]) + source = chunk["source"] + + for query_phrase in query_phrases: + if clean_for_match(query_phrase) in chunk_text: + document_match_count[source] += 1 + + if source not in document_best_chunk: + document_best_chunk[source] = (chunk["chunk"], {"source": source}) + + sorted_sources = sorted( + document_match_count, key=lambda k: -document_match_count[k]) + deduped_forced = [ + document_best_chunk[source] + for source in sorted_sources + if document_match_count[source] > 0 + ] q_embed = model.encode(query) results = collection.query( query_embeddings=[q_embed.tolist()], n_results=top_k * 2) - semantic_docs = [(document, metadata) for document, metadata in zip( + semantic_documents = [(document, metadata) for document, metadata in zip( results['documents'][0], results['metadatas'][0])] + seen = set(metadata['source'] for _, metadata in deduped_forced) merged = deduped_forced.copy() - seen = set(meta['source'] for _, meta in deduped_forced) - for document, metadata in semantic_docs: + for document, metadata in semantic_documents: if metadata['source'] not in seen and len(merged) < top_k: merged.append((document, metadata)) seen.add(metadata['source']) |