aboutsummaryrefslogtreecommitdiff
path: root/apps/backend/src
diff options
context:
space:
mode:
authorDhravya Shah <[email protected]>2025-02-18 21:20:15 -0700
committerDhravya Shah <[email protected]>2025-02-18 21:20:15 -0700
commit6cfc234cc059f0aa3f9e47d01bff5965a908a8a1 (patch)
tree4e00c1a9aef015f80e00542537a7c66f98740812 /apps/backend/src
parentimplement hybrid search (diff)
downloadarchived-supermemory-6cfc234cc059f0aa3f9e47d01bff5965a908a8a1.tar.xz
archived-supermemory-6cfc234cc059f0aa3f9e47d01bff5965a908a8a1.zip
Implemented proper hybrid search that takes date relevancy into consideration
Diffstat (limited to 'apps/backend/src')
-rw-r--r--apps/backend/src/routes/actions.ts185
1 files changed, 71 insertions, 114 deletions
diff --git a/apps/backend/src/routes/actions.ts b/apps/backend/src/routes/actions.ts
index d723bf2e..0bc26052 100644
--- a/apps/backend/src/routes/actions.ts
+++ b/apps/backend/src/routes/actions.ts
@@ -89,8 +89,7 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
});
const googleClient = wrapAISDKModel(
- openai(c.env).chat("gpt-4o-mini-2024-07-18")
- );
+ openai(c.env).chat("gpt-4o-mini-2024-07-18"));
// Get last user message and generate embedding in parallel with thread creation
let lastUserMessage = coreMessages.findLast((i) => i.role === "user");
@@ -125,7 +124,15 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
return c.json({ error: "Failed to generate embedding" }, 500);
}
- // Perform hybrid search for context retrieval
+ // Build the vector-similarity and text-rank SQL expressions once so the same fragments can be reused in the select, where and orderBy clauses
+ const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)`;
+ const textSearchRank = sql<number>`ts_rank_cd((
+ setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
+ setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
+ setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
+ setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
+ ), plainto_tsquery('english', ${queryText}))`;
+
const finalResults = await db
.select({
id: documents.id,
@@ -138,43 +145,25 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
userId: documents.userId,
description: documents.description,
ogImage: documents.ogImage,
- vectorSimilarity: sql<number>`1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)`,
- textSimilarity: sql<number>`ts_rank((
- setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
- setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
- setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
- setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
- ), plainto_tsquery('english', ${queryText}))`,
- hybridScore: sql<number>`(
- 0.75 * (1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)) +
- 0.25 * ts_rank((
- setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
- setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
- setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
- setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
- ), plainto_tsquery('english', ${queryText}))
- )::float`,
+ similarity: vectorSimilarity,
+ textRank: textSearchRank,
})
.from(chunk)
.innerJoin(documents, eq(chunk.documentId, documents.id))
.where(
and(
eq(documents.userId, user.id),
- sql`1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector) > 0.4`
+ sql`${vectorSimilarity} > 0.5`
)
)
.orderBy(
desc(sql<number>`(
- 0.75 * (1 - (embeddings <=> ${JSON.stringify(embedding[0])}::vector)) +
- 0.25 * ts_rank((
- setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
- setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
- setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
- setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
- ), plainto_tsquery('english', ${queryText}))
- )::float`)
+ 0.6 * ${vectorSimilarity} +
+ 0.25 * ${textSearchRank} +
+ 0.15 * (1.0 / (1.0 + extract(epoch from age(${documents.updatedAt})) / (90 * 24 * 60 * 60)))
+ )::float`)
)
- .limit(5);
+ .limit(15);
const cleanDocumentsForContext = finalResults.map((d) => ({
title: d.title,
@@ -202,37 +191,27 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
try {
const data = new StreamData();
// De-duplicate results by document id so the same document is not shown more than once
- const uniqueResults = finalResults.reduce(
- (acc, current) => {
- const existingResult = acc.find((item) => item.id === current.id);
- if (!existingResult) {
- acc.push(current);
- }
- return acc;
- },
- [] as typeof finalResults
- );
+ const uniqueResults = finalResults.reduce((acc, current) => {
+ const existingResult = acc.find(item => item.id === current.id);
+ if (!existingResult) {
+ acc.push(current);
+ }
+ return acc;
+ }, [] as typeof finalResults);
data.appendMessageAnnotation(
- uniqueResults.map(
- (r) =>
- ({
- id: String(r.id),
- content: String(r.content || ""),
- type: String(r.type || ""),
- url: String(r.url || ""),
- title: String(r.title || ""),
- description: String(r.description || ""),
- ogImage: String(r.ogImage || ""),
- userId: String(r.userId),
- createdAt:
- r.createdAt instanceof Date ? r.createdAt.toISOString() : "",
- updatedAt:
- r.updatedAt instanceof Date
- ? r.updatedAt.toISOString()
- : null,
- }) as const
- )
+ uniqueResults.map((r) => ({
+ id: r.id,
+ content: r.content,
+ type: r.type,
+ url: r.url,
+ title: r.title,
+ description: r.description,
+ ogImage: r.ogImage,
+ userId: r.userId,
+ createdAt: r.createdAt.toISOString(),
+ updatedAt: r.updatedAt?.toISOString() || null,
+ }))
);
const result = await streamText({
@@ -511,22 +490,10 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
limit: z.number().min(1).max(50).default(10),
threshold: z.number().min(0).max(1).default(0),
spaces: z.array(z.string()).optional(),
- weights: z
- .object({
- semantic: z.number().min(0).max(1).default(0.75),
- keyword: z.number().min(0).max(1).default(0.25),
- })
- .optional(),
})
),
async (c) => {
- const {
- query,
- limit,
- threshold,
- spaces,
- weights = { semantic: 0.75, keyword: 0.25 },
- } = c.req.valid("json");
+ const { query, limit, threshold, spaces } = c.req.valid("json");
const user = c.get("user");
if (!user) {
@@ -543,36 +510,32 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
.from(spaceInDb)
.where(eq(spaceInDb.uuid, spaceId))
.limit(1);
-
+
if (space.length === 0) return null;
return {
id: space[0].id,
ownerId: space[0].ownerId,
- uuid: space[0].uuid,
+ uuid: space[0].uuid
};
})
);
- const validSpaces = spaceDetails.filter(
- (s): s is NonNullable<typeof s> => s !== null
- );
- const unauthorized = validSpaces.filter((s) => s.ownerId !== user.id);
+ // Filter out any null values and check permissions
+ const validSpaces = spaceDetails.filter((s): s is NonNullable<typeof s> => s !== null);
+ const unauthorized = validSpaces.filter(s => s.ownerId !== user.id);
if (unauthorized.length > 0) {
return c.json(
{
error: "Space permission denied",
- details: unauthorized.map((s) => s.uuid).join(", "),
+ details: unauthorized.map(s => s.uuid).join(", "),
},
403
);
}
- spaces.splice(
- 0,
- spaces.length,
- ...validSpaces.map((s) => s.id.toString())
- );
+ // Replace UUIDs with IDs for the database query
+ spaces.splice(0, spaces.length, ...validSpaces.map(s => s.id.toString()));
}
try {
@@ -588,37 +551,37 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
);
}
- // Perform hybrid search using both vector similarity and full-text search
- const results = await database(c.env.HYPERDRIVE.connectionString)
+ // Build the vector-similarity and text-rank SQL expressions once so the same fragments can be reused in the select, where and orderBy clauses
+ const vectorSimilarity = sql<number>`1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector)`;
+ const textSearchRank = sql<number>`ts_rank_cd((
+ setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
+ setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
+ setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
+ setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
+ ), plainto_tsquery('english', ${query}))`;
+
+ const results = await db
.select({
id: documents.id,
uuid: documents.uuid,
content: documents.content,
+ type: documents.type,
+ url: documents.url,
+ title: documents.title,
createdAt: documents.createdAt,
- chunkContent: chunk.textContent,
- vectorSimilarity: sql<number>`1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector)`,
- textSimilarity: sql<number>`ts_rank((
- setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
- setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
- setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
- setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
- ), plainto_tsquery('english', ${query}))`,
- hybridScore: sql<number>`(
- ${weights.semantic} * (1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector)) +
- ${weights.keyword} * ts_rank((
- setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
- setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
- setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
- setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
- ), plainto_tsquery('english', ${query}))
- )::float`,
+ updatedAt: documents.updatedAt,
+ userId: documents.userId,
+ description: documents.description,
+ ogImage: documents.ogImage,
+ similarity: vectorSimilarity,
+ textRank: textSearchRank,
})
.from(chunk)
.innerJoin(documents, eq(chunk.documentId, documents.id))
.where(
and(
eq(documents.userId, user.id),
- sql`1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector) >= ${threshold}`,
+ sql`${vectorSimilarity} > ${threshold}`,
...(spaces && spaces.length > 0
? [
exists(
@@ -641,23 +604,17 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
)
.orderBy(
desc(sql<number>`(
- ${weights.semantic} * (1 - (embeddings <=> ${JSON.stringify(embeddings.data[0])}::vector)) +
- ${weights.keyword} * ts_rank((
- setweight(to_tsvector('english', coalesce(${documents.content}, '')),'A') ||
- setweight(to_tsvector('english', coalesce(${documents.title}, '')),'B') ||
- setweight(to_tsvector('english', coalesce(${documents.description}, '')),'C') ||
- setweight(to_tsvector('english', coalesce(${documents.url}, '')),'D')
- ), plainto_tsquery('english', ${query}))
- )::float`)
+ 0.6 * ${vectorSimilarity} +
+ 0.25 * ${textSearchRank} +
+ 0.15 * (1.0 / (1.0 + extract(epoch from age(${documents.updatedAt})) / (90 * 24 * 60 * 60)))
+ )::float`)
)
.limit(limit);
return c.json({
results: results.map((r) => ({
...r,
- vectorSimilarity: Number(r.vectorSimilarity.toFixed(4)),
- textSimilarity: Number(r.textSimilarity.toFixed(4)),
- hybridScore: Number(r.hybridScore.toFixed(4)),
+ similarity: Number(r.similarity.toFixed(4)),
})),
});
} catch (error) {