diff options
| author | Dhravya Shah <[email protected]> | 2025-02-18 21:20:15 -0700 |
|---|---|---|
| committer | Dhravya Shah <[email protected]> | 2025-02-18 21:20:15 -0700 |
| commit | 6cfc234cc059f0aa3f9e47d01bff5965a908a8a1 (patch) | |
| tree | 4e00c1a9aef015f80e00542537a7c66f98740812 /apps/backend/src | |
| parent | implement hybrid search (diff) | |
| download | archived-supermemory-6cfc234cc059f0aa3f9e47d01bff5965a908a8a1.tar.xz archived-supermemory-6cfc234cc059f0aa3f9e47d01bff5965a908a8a1.zip | |
implemented proper hybrid search with date relevancy into consideration
Diffstat (limited to 'apps/backend/src')
| -rw-r--r-- | apps/backend/src/routes/actions.ts | 185 |
1 files changed, 71 insertions, 114 deletions
diff --git a/apps/backend/src/routes/actions.ts b/apps/backend/src/routes/actions.ts index d723bf2e..0bc26052 100644 --- a/apps/backend/src/routes/actions.ts +++ b/apps/backend/src/routes/actions.ts @@ -89,8 +89,7 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() }); const googleClient = wrapAISDKModel( - openai(c.env).chat("gpt-4o-mini-2024-07-18") - ); + openai(c.env).chat("gpt-4o-mini-2024-07-18")); // Get last user message and generate embedding in parallel with thread creation let lastUserMessage = coreMessages.findLast((i) => i.role === "user"); @@ -125,7 +124,15 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() return c.json({ error: "Failed to generate embedding" }, 500); } - // Perform hybrid search for context retrieval + // Pre-compute the vector similarity expression to avoid multiple calculations + const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)`; + const textSearchRank = sql<number>`ts_rank_cd(( + setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') || + setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') || + setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') || + setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D') + ), plainto_tsquery('english', ${queryText}))`; + const finalResults = await db .select({ id: documents.id, @@ -138,43 +145,25 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() userId: documents.userId, description: documents.description, ogImage: documents.ogImage, - vectorSimilarity: sql<number>`1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)`, - textSimilarity: sql<number>`ts_rank(( - setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') || - setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') || - setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') || - setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D') - ), plainto_tsquery('english', ${queryText}))`, - hybridScore: sql<number>`( - 0.75 * (1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)) + - 0.25 * ts_rank(( - setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') || - setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') || - setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') || - setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D') - ), plainto_tsquery('english', ${queryText})) - )::float`, + similarity: vectorSimilarity, + textRank: textSearchRank, }) .from(chunk) .innerJoin(documents, eq(chunk.documentId, documents.id)) .where( and( eq(documents.userId, user.id), - sql`1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector) > 0.4` + sql`${vectorSimilarity} > 0.5` ) ) .orderBy( desc(sql<number>`( - 0.75 * (1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)) + - 0.25 * ts_rank(( - setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') || - setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') || - setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') || - setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D') - ), plainto_tsquery('english', ${queryText})) - )::float`) + 0.6 * ${vectorSimilarity} + + 0.25 * ${textSearchRank} + + 0.15 * (1.0 / (1.0 + extract(epoch from age(${documents.updatedAt})) / (90 * 24 * 60 * 60))) + )::float`) ) - .limit(5); + .limit(15); const cleanDocumentsForContext = finalResults.map((d) => ({ title: d.title, @@ -202,37 +191,27 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() try { const data = new StreamData(); // De-duplicate chunks by URL to avoid showing duplicate content - const uniqueResults = finalResults.reduce( - (acc, current) => { - const existingResult = acc.find((item) => item.id === current.id); - if (!existingResult) { - acc.push(current); - } - return acc; - }, - [] as typeof finalResults - ); + const uniqueResults = finalResults.reduce((acc, current) => { + const existingResult = acc.find(item => item.id === current.id); + if (!existingResult) { + acc.push(current); + } + return acc; + }, [] as typeof finalResults); data.appendMessageAnnotation( - uniqueResults.map( - (r) => - ({ - id: String(r.id), - content: String(r.content || ""), - type: String(r.type || ""), - url: String(r.url || ""), - title: String(r.title || ""), - description: String(r.description || ""), - ogImage: String(r.ogImage || ""), - userId: String(r.userId), - createdAt: - r.createdAt instanceof Date ? r.createdAt.toISOString() : "", - updatedAt: - r.updatedAt instanceof Date - ? r.updatedAt.toISOString() - : null, - }) as const - ) + uniqueResults.map((r) => ({ + id: r.id, + content: r.content, + type: r.type, + url: r.url, + title: r.title, + description: r.description, + ogImage: r.ogImage, + userId: r.userId, + createdAt: r.createdAt.toISOString(), + updatedAt: r.updatedAt?.toISOString() || null, + })) ); const result = await streamText({ @@ -511,22 +490,10 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() limit: z.number().min(1).max(50).default(10), threshold: z.number().min(0).max(1).default(0), spaces: z.array(z.string()).optional(), - weights: z - .object({ - semantic: z.number().min(0).max(1).default(0.75), - keyword: z.number().min(0).max(1).default(0.25), - }) - .optional(), }) ), async (c) => { - const { - query, - limit, - threshold, - spaces, - weights = { semantic: 0.75, keyword: 0.25 }, - } = c.req.valid("json"); + const { query, limit, threshold, spaces } = c.req.valid("json"); const user = c.get("user"); if (!user) { @@ -543,36 +510,32 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() .from(spaceInDb) .where(eq(spaceInDb.uuid, spaceId)) .limit(1); - + if (space.length === 0) return null; return { id: space[0].id, ownerId: space[0].ownerId, - uuid: space[0].uuid, + uuid: space[0].uuid }; }) ); - const validSpaces = spaceDetails.filter( - (s): s is NonNullable<typeof s> => s !== null - ); - const unauthorized = validSpaces.filter((s) => s.ownerId !== user.id); + // Filter out any null values and check permissions + const validSpaces = spaceDetails.filter((s): s is NonNullable<typeof s> => s !== null); + const unauthorized = validSpaces.filter(s => s.ownerId !== user.id); if (unauthorized.length > 0) { return c.json( { error: "Space permission denied", - details: unauthorized.map((s) => s.uuid).join(", "), + details: unauthorized.map(s => s.uuid).join(", "), }, 403 ); } - spaces.splice( - 0, - spaces.length, - ...validSpaces.map((s) => s.id.toString()) - ); + // Replace UUIDs with IDs for the database query + spaces.splice(0, spaces.length, ...validSpaces.map(s => s.id.toString())); } try { @@ -588,37 +551,37 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() ); } - // Perform hybrid search using both vector similarity and full-text search - const results = await database(c.env.HYPERDRIVE.connectionString) + // Pre-compute the vector similarity expression to avoid multiple calculations + const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector)`; + const textSearchRank = sql<number>`ts_rank_cd(( + setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') || + setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') || + setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') || + setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D') + ), plainto_tsquery('english', ${query}))`; + + const results = await db .select({ id: documents.id, uuid: documents.uuid, content: documents.content, + type: documents.type, + url: documents.url, + title: documents.title, createdAt: documents.createdAt, - chunkContent: chunk.textContent, - vectorSimilarity: sql<number>`1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector)`, - textSimilarity: sql<number>`ts_rank(( - setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') || - setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') || - setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') || - setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D') - ), plainto_tsquery('english', ${query}))`, - hybridScore: sql<number>`( - ${weights.semantic} * (1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector)) + - ${weights.keyword} * ts_rank(( - setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') || - setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') || - setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') || - setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D') - ), plainto_tsquery('english', ${query})) - )::float`, + updatedAt: documents.updatedAt, + userId: documents.userId, + description: documents.description, + ogImage: documents.ogImage, + similarity: vectorSimilarity, + textRank: textSearchRank, }) .from(chunk) .innerJoin(documents, eq(chunk.documentId, documents.id)) .where( and( eq(documents.userId, user.id), - sql`1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector) >= ${threshold}`, + sql`${vectorSimilarity} > ${threshold}`, ...(spaces && spaces.length > 0 ? [ exists( @@ -641,23 +604,17 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() ) .orderBy( desc(sql<number>`( - ${weights.semantic} * (1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector)) + - ${weights.keyword} * ts_rank(( - setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') || - setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') || - setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') || - setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D') - ), plainto_tsquery('english', ${query})) - )::float`) + 0.6 * ${vectorSimilarity} + + 0.25 * ${textSearchRank} + + 0.15 * (1.0 / (1.0 + extract(epoch from age(${documents.updatedAt})) / (90 * 24 * 60 * 60))) + )::float`) ) .limit(limit); return c.json({ results: results.map((r) => ({ ...r, - vectorSimilarity: Number(r.vectorSimilarity.toFixed(4)), - textSimilarity: Number(r.textSimilarity.toFixed(4)), - hybridScore: Number(r.hybridScore.toFixed(4)), + similarity: Number(r.similarity.toFixed(4)), })), }); } catch (error) { |