diff options
| author | Dhravya Shah <[email protected]> | 2025-02-18 21:51:26 -0700 |
|---|---|---|
| committer | Dhravya Shah <[email protected]> | 2025-02-18 21:51:26 -0700 |
| commit | d5477b4ef3f1486f4de438b1582e80544ff62db0 (patch) | |
| tree | 426eea9bd21d5d3a8163a1c93e95a16628b739da /apps/backend/src | |
| parent | implemented proper hybrid search with date relevancy into consideration (diff) | |
| download | supermemory-d5477b4ef3f1486f4de438b1582e80544ff62db0.tar.xz supermemory-d5477b4ef3f1486f4de438b1582e80544ff62db0.zip | |
hybrid rag looks good now
Diffstat (limited to 'apps/backend/src')
| -rw-r--r-- | apps/backend/src/routes/actions.ts | 319 |
1 files changed, 201 insertions, 118 deletions
diff --git a/apps/backend/src/routes/actions.ts b/apps/backend/src/routes/actions.ts index 0bc26052..37cc18e4 100644 --- a/apps/backend/src/routes/actions.ts +++ b/apps/backend/src/routes/actions.ts @@ -89,7 +89,8 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() }); const googleClient = wrapAISDKModel( - openai(c.env).chat("gpt-4o-mini-2024-07-18")); + openai(c.env).chat("gpt-4o-mini-2024-07-18") + ); // Get last user message and generate embedding in parallel with thread creation let lastUserMessage = coreMessages.findLast((i) => i.role === "user"); @@ -124,96 +125,137 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() return c.json({ error: "Failed to generate embedding" }, 500); } - // Pre-compute the vector similarity expression to avoid multiple calculations - const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)`; - const textSearchRank = sql<number>`ts_rank_cd(( - setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') || - setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') || - setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') || - setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D') - ), plainto_tsquery('english', ${queryText}))`; - - const finalResults = await db - .select({ - id: documents.id, - content: documents.content, - type: documents.type, - url: documents.url, - title: documents.title, - createdAt: documents.createdAt, - updatedAt: documents.updatedAt, - userId: documents.userId, - description: documents.description, - ogImage: documents.ogImage, - similarity: vectorSimilarity, - textRank: textSearchRank, - }) - .from(chunk) - .innerJoin(documents, eq(chunk.documentId, documents.id)) - .where( - and( - eq(documents.userId, user.id), - sql`${vectorSimilarity} > 0.5` - ) - ) - .orderBy( - desc(sql<number>`( - 0.6 * ${vectorSimilarity} + - 0.25 * ${textSearchRank} + - 0.15 * (1.0 / (1.0 + extract(epoch from age(${documents.updatedAt})) / (90 * 24 * 60 * 60))) - )::float`) - ) - .limit(15); - - const cleanDocumentsForContext = finalResults.map((d) => ({ - title: d.title, - description: d.description, - url: d.url, - type: d.type, - content: d.content, - })); - - if (lastUserMessage) { - lastUserMessage.content = - typeof lastUserMessage.content === "string" - ? lastUserMessage.content + - `<context>${JSON.stringify(cleanDocumentsForContext)}</context>` - : [ - ...lastUserMessage.content, - { - type: "text", - text: `<context>${JSON.stringify(cleanDocumentsForContext)}</context>`, - }, - ]; - coreMessages[coreMessages.length - 1] = lastUserMessage; - } - try { const data = new StreamData(); - // De-duplicate chunks by URL to avoid showing duplicate content - const uniqueResults = finalResults.reduce((acc, current) => { - const existingResult = acc.find(item => item.id === current.id); - if (!existingResult) { - acc.push(current); - } - return acc; - }, [] as typeof finalResults); - - data.appendMessageAnnotation( - uniqueResults.map((r) => ({ - id: r.id, - content: r.content, - type: r.type, - url: r.url, - title: r.title, - description: r.description, - ogImage: r.ogImage, - userId: r.userId, - createdAt: r.createdAt.toISOString(), - updatedAt: r.updatedAt?.toISOString() || null, - })) + + // Pre-compute the vector similarity expression + const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)`; + const textSearchRank = sql<number>`ts_rank_cd( + to_tsvector('english', coalesce(${chunk.textContent}, '')), + plainto_tsquery('english', ${queryText}) + )`; + + // Get matching chunks with document info + const matchingChunks = await db + .select({ + chunkId: chunk.id, + documentId: chunk.documentId, + textContent: chunk.textContent, + orderInDocument: chunk.orderInDocument, + metadata: chunk.metadata, + similarity: vectorSimilarity, + textRank: textSearchRank, + // Document fields + docId: documents.id, + docUuid: documents.uuid, + docContent: documents.content, + docType: documents.type, + docUrl: documents.url, + docTitle: documents.title, + docDescription: documents.description, + docOgImage: documents.ogImage, + }) + .from(chunk) + .innerJoin(documents, eq(chunk.documentId, documents.id)) + .where( + and(eq(documents.userId, user.id), sql`${vectorSimilarity} > 0.5`) + ) + .orderBy( + desc(sql<number>`( + 0.6 * ${vectorSimilarity} + + 0.25 * ${textSearchRank} + + 0.15 * (1.0 / (1.0 + extract(epoch from age(${documents.updatedAt})) / (90 * 24 * 60 * 60))) + )::float`) + ) + .limit(15); + + // Get unique document IDs from matching chunks + const uniqueDocIds = [ + ...new Set(matchingChunks.map((c) => c.documentId)), + ]; + + // Fetch all chunks for these documents to get context + const contextChunks = await db + .select({ + id: chunk.id, + documentId: chunk.documentId, + textContent: chunk.textContent, + orderInDocument: chunk.orderInDocument, + metadata: chunk.metadata, + }) + .from(chunk) + .where(inArray(chunk.documentId, uniqueDocIds)) + .orderBy(chunk.documentId, chunk.orderInDocument); + + // Group chunks by document + const chunksByDocument = new Map<number, typeof contextChunks>(); + for (const chunk of contextChunks) { + const docChunks = chunksByDocument.get(chunk.documentId) || []; + docChunks.push(chunk); + chunksByDocument.set(chunk.documentId, docChunks); + } + + // Create context with surrounding chunks + const contextualResults = matchingChunks.map((match) => { + const docChunks = chunksByDocument.get(match.documentId) || []; + const matchIndex = docChunks.findIndex((c) => c.id === match.chunkId); + + // Get surrounding chunks (1 before and 1 after) + const start = Math.max(0, matchIndex - 1); + const end = Math.min(docChunks.length, matchIndex + 2); + const relevantChunks = docChunks.slice(start, end); + + return { + id: match.docId, + title: match.docTitle, + description: match.docDescription, + url: match.docUrl, + type: match.docType, + content: relevantChunks.map((c) => c.textContent).join("\n"), + similarity: Number(match.similarity.toFixed(4)), + chunks: relevantChunks.map((c) => ({ + id: c.id, + content: c.textContent, + orderInDocument: c.orderInDocument, + metadata: c.metadata, + isMatch: c.id === match.chunkId, + })), + }; + }); + + // Remove duplicates based on document ID + const uniqueResults = contextualResults.reduce( + (acc, current) => { + const existingDoc = acc.find((doc) => doc.id === current.id); + if (!existingDoc) { + acc.push(current); + } else if (current.similarity > existingDoc.similarity) { + // Replace if current match is better + const index = acc.findIndex((doc) => doc.id === current.id); + acc[index] = current; + } + return acc; + }, + [] as typeof contextualResults ); + data.appendMessageAnnotation(uniqueResults); + + if (lastUserMessage) { + lastUserMessage.content = + typeof lastUserMessage.content === "string" + ? lastUserMessage.content + + `<context>${JSON.stringify(uniqueResults)}</context>` + : [ + ...lastUserMessage.content, + { + type: "text", + text: `<context>${JSON.stringify(uniqueResults)}</context>`, + }, + ]; + coreMessages[coreMessages.length - 1] = lastUserMessage; + } + const result = await streamText({ model: googleClient, experimental_providerMetadata: { @@ -267,7 +309,7 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() role: "assistant", content: completion.text + - `<context>[${JSON.stringify(finalResults)}]</context>`, + `<context>[${JSON.stringify(uniqueResults)}]</context>`, }, ]; @@ -279,6 +321,8 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() } } catch (error) { console.error("Failed to update thread:", error); + } finally { + await data.close(); } }, }); @@ -510,32 +554,38 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() .from(spaceInDb) .where(eq(spaceInDb.uuid, spaceId)) .limit(1); - + if (space.length === 0) return null; return { id: space[0].id, ownerId: space[0].ownerId, - uuid: space[0].uuid + uuid: space[0].uuid, }; }) ); // Filter out any null values and check permissions - const validSpaces = spaceDetails.filter((s): s is NonNullable<typeof s> => s !== null); - const unauthorized = validSpaces.filter(s => s.ownerId !== user.id); + const validSpaces = spaceDetails.filter( + (s): s is NonNullable<typeof s> => s !== null + ); + const unauthorized = validSpaces.filter((s) => s.ownerId !== user.id); if (unauthorized.length > 0) { return c.json( { error: "Space permission denied", - details: unauthorized.map(s => s.uuid).join(", "), + details: unauthorized.map((s) => s.uuid).join(", "), }, 403 ); } // Replace UUIDs with IDs for the database query - spaces.splice(0, spaces.length, ...validSpaces.map(s => s.id.toString())); + spaces.splice( + 0, + spaces.length, + ...validSpaces.map((s) => s.id.toString()) + ); } try { @@ -553,28 +603,32 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() // Pre-compute the vector similarity expression to avoid multiple calculations const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector)`; - const textSearchRank = sql<number>`ts_rank_cd(( - setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') || - setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') || - setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') || - setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D') - ), plainto_tsquery('english', ${query}))`; - + const textSearchRank = sql<number>`ts_rank_cd( + to_tsvector('english', coalesce(${chunk.textContent}, '')), + plainto_tsquery('english', ${query}) + )`; + + // First get the top matching chunks const results = await db .select({ - id: documents.id, - uuid: documents.uuid, - content: documents.content, - type: documents.type, - url: documents.url, - title: documents.title, - createdAt: documents.createdAt, - updatedAt: documents.updatedAt, - userId: documents.userId, - description: documents.description, - ogImage: documents.ogImage, + chunkId: chunk.id, + documentId: chunk.documentId, + textContent: chunk.textContent, + orderInDocument: chunk.orderInDocument, + metadata: chunk.metadata, similarity: vectorSimilarity, textRank: textSearchRank, + // Document fields + docUuid: documents.uuid, + docContent: documents.content, + docType: documents.type, + docUrl: documents.url, + docTitle: documents.title, + docCreatedAt: documents.createdAt, + docUpdatedAt: documents.updatedAt, + docUserId: documents.userId, + docDescription: documents.description, + docOgImage: documents.ogImage, }) .from(chunk) .innerJoin(documents, eq(chunk.documentId, documents.id)) @@ -611,12 +665,41 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() ) .limit(limit); - return c.json({ - results: results.map((r) => ({ - ...r, - similarity: Number(r.similarity.toFixed(4)), - })), - }); + // Group results by document and take the best matching chunk + const documentResults = new Map<number, (typeof results)[0]>(); + for (const result of results) { + const existingResult = documentResults.get(result.documentId); + if ( + !existingResult || + result.similarity > existingResult.similarity + ) { + documentResults.set(result.documentId, result); + } + } + + // Convert back to array and format response + const finalResults = Array.from(documentResults.values()).map((r) => ({ + id: r.documentId, + uuid: r.docUuid, + content: r.docContent, + type: r.docType, + url: r.docUrl, + title: r.docTitle, + createdAt: r.docCreatedAt, + updatedAt: r.docUpdatedAt, + userId: r.docUserId, + description: r.docDescription, + ogImage: r.docOgImage, + similarity: Number(r.similarity.toFixed(4)), + matchingChunk: { + id: r.chunkId, + content: r.textContent, + orderInDocument: r.orderInDocument, + metadata: r.metadata, + }, + })); + + return c.json({ results: finalResults }); } catch (error) { console.error("[Search Error]", error); return c.json( |