diff options
| author | Dhravya Shah <[email protected]> | 2025-01-27 21:02:43 -0700 |
|---|---|---|
| committer | Dhravya Shah <[email protected]> | 2025-01-27 21:02:43 -0700 |
| commit | d2bf8d6623368f9ee1abb7d9add237fe6d003f1c (patch) | |
| tree | 8f98c0793c62c177a45e6ad156b2bad872d6e474 /apps/backend/src | |
| parent | change embedding model (diff) | |
| download | supermemory-d2bf8d6623368f9ee1abb7d9add237fe6d003f1c.tar.xz supermemory-d2bf8d6623368f9ee1abb7d9add237fe6d003f1c.zip | |
change embedding model
Diffstat (limited to 'apps/backend/src')
| -rw-r--r-- | apps/backend/src/routes/actions.ts | 249 | ||||
| -rw-r--r-- | apps/backend/src/utils/fetchers.ts | 11 |
2 files changed, 59 insertions, 201 deletions
diff --git a/apps/backend/src/routes/actions.ts b/apps/backend/src/routes/actions.ts index b9e6a4c2..1ac02fdd 100644 --- a/apps/backend/src/routes/actions.ts +++ b/apps/backend/src/routes/actions.ts @@ -54,8 +54,6 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() const { messages, threadId } = await c.req.valid("json"); - // TODO: add rate limiting - const unfilteredCoreMessages = convertToCoreMessages( (messages as Message[]) .filter((m) => m.content.length > 0) @@ -67,125 +65,66 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() ? `<context>${JSON.stringify(m.annotations)}</context>` : ""), experimental_attachments: - m.experimental_attachments && - m.experimental_attachments.length > 0 + m.experimental_attachments?.length && + m.experimental_attachments?.length > 0 ? m.experimental_attachments : (m.data as { files: [] })?.files, })) ); - // make sure that there is no empty messages. if there is, remove it. const coreMessages = unfilteredCoreMessages.filter( (message) => message.content.length > 0 ); - // .map(async (c) => { - // if ( - // Array.isArray(c.content) && - // c.content.some((c) => c.type !== "text") - // ) { - // // convert attachments (IMAGE and files) to base64 by fetching them - // const attachments = c.content.filter((c) => c.type !== "text"); - // const base64Attachments = await Promise.all( - // attachments.map(async (a) => { - // const type = (a as ImagePart | FilePart).type; - // if (type === "image") { - // const response = await fetch((a as ImagePart).image.toString()); - // return response.arrayBuffer(); - // } else if (type === "file") { - // const response = await fetch((a as FilePart).data.toString()); - // return response.arrayBuffer(); - // } - // }) - // ); - // } - // }); - - console.log("Core messages", JSON.stringify(coreMessages, null, 2)); - - let threadUuid = threadId; + const db = database(c.env.HYPERDRIVE.connectionString); const { initLogger, wrapAISDKModel } = await import("braintrust"); + // Initialize clients and loggers const logger = initLogger({ projectName: "supermemory", apiKey: c.env.BRAINTRUST_API_KEY, }); - // const gemini = createOpenAI({ - // apiKey: c.env.GEMINI_API_KEY, - // baseURL: "https://generativelanguage.googleapis.com/v1beta/openai/", - // }); - const openaiClient = openai(c.env); - - const googleClient = wrapAISDKModel( - // google(c.env.GEMINI_API_KEY).chat("gemini-exp-1206") - openai(c.env).chat("gpt-4o") - ); - - // Create new thread if none exists - if (!threadUuid) { - const uuid = randomId(); - const newThread = await database(c.env.HYPERDRIVE.connectionString) - .insert(chatThreads) - .values({ - firstMessage: messages[0].content, - userId: user.id, - uuid: uuid, - messages: coreMessages, - }) - .returning(); - - threadUuid = newThread[0].uuid; - } - const openAi = openai(c.env); - - - - let lastUserMessage = coreMessages - .reverse() - .find((i) => i.role === "user"); + const googleClient = wrapAISDKModel(openai(c.env).chat("gpt-4o")); - // get the text of lastUserMEssage + // Get last user message and generate embedding in parallel with thread creation + let lastUserMessage = coreMessages.findLast((i) => i.role === "user"); const queryText = typeof lastUserMessage?.content === "string" ? lastUserMessage.content : lastUserMessage?.content.map((c) => (c as TextPart).text).join(""); - console.log("querytext", queryText); - - if (!queryText ||queryText.length === 0) { - return c.json({ error: "Failed to generate embedding for query" }, 500); + if (!queryText || queryText.length === 0) { + return c.json({ error: "Empty query" }, 400); } - const embedStart = performance.now(); - const { data: embedding } = await c.env.AI.run("@cf/baai/bge-base-en-v1.5", { - text: queryText, - }); - const embedEnd = performance.now(); - console.log(`Embedding generation took ${embedEnd - embedStart}ms`); + // Run embedding generation and thread creation in parallel + const [{ data: embedding }, thread] = await Promise.all([ + c.env.AI.run("@cf/baai/bge-base-en-v1.5", { text: queryText }), + !threadId + ? db + .insert(chatThreads) + .values({ + firstMessage: messages[0].content, + userId: user.id, + uuid: randomId(), + messages: coreMessages, + }) + .returning() + : null, + ]); + + const threadUuid = threadId || thread?.[0].uuid; if (!embedding) { - return c.json({ error: "Failed to generate embedding for query" }, 500); + return c.json({ error: "Failed to generate embedding" }, 500); } - // Perform semantic search using cosine similarity - // Log the query text to debug what we're searching for - console.log("Searching for:", queryText); - console.log("user id", user.id); - - const similarity = sql<number>`1 - (${cosineDistance( - chunk.embeddings, - embedding[0] - )})`; - - // Find similar chunks using cosine similarity - // Join with documents table to get chunks only from documents the user has access to - // First get all results to normalize - // Get top 20 results first to avoid processing entire dataset - const dbQueryStart = performance.now(); - const topResults = await database(c.env.HYPERDRIVE.connectionString) + // Perform semantic search + const similarity = sql<number>`1 - (${cosineDistance(chunk.embeddings, embedding[0])})`; + + const finalResults = await db .select({ - similarity, id: documents.id, content: documents.content, type: documents.type, @@ -200,62 +139,10 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() .from(chunk) .innerJoin(documents, eq(chunk.documentId, documents.id)) .where(and(eq(documents.userId, user.id), sql`${similarity} > 0.4`)) - .orderBy(desc(similarity)); - - // Get unique documents with their highest similarity chunks - const uniqueDocuments = Object.values( - topResults.reduce( - (acc, curr) => { - if ( - !acc[curr.id] || - acc[curr.id].content === curr.content || - acc[curr.id].url === curr.url - ) { - acc[curr.id] = curr; - } - return acc; - }, - {} as Record<number, (typeof topResults)[0]> - ) - ).slice(0, 5); - - const dbQueryEnd = performance.now(); - console.log(`Database query took ${dbQueryEnd - dbQueryStart}ms`); - - // Calculate min/max once for the subset - const processingStart = performance.now(); - const minSimilarity = Math.min( - ...uniqueDocuments.map((r) => r.similarity) - ); - const maxSimilarity = Math.max( - ...uniqueDocuments.map((r) => r.similarity) - ); - const range = maxSimilarity - minSimilarity; - - // Normalize the results - const normalizedResults = uniqueDocuments.map((result) => ({ - ...result, - normalizedSimilarity: - range === 0 ? 1 : (result.similarity - minSimilarity) / range, - })); - - // Get either all results above 0.6 threshold, or at least top 3 results - const results = normalizedResults - .sort((a, b) => b.normalizedSimilarity - a.normalizedSimilarity) - .slice( - 0, - Math.max( - 3, - normalizedResults.filter((r) => r.normalizedSimilarity > 0.6).length - ) - ); + .orderBy(desc(similarity)) + .limit(5); - const processingEnd = performance.now(); - console.log( - `Results processing took ${processingEnd - processingStart}ms` - ); - - const cleanDocumentsForContext = results.map((d) => ({ + const cleanDocumentsForContext = finalResults.map((d) => ({ title: d.title, description: d.description, url: d.url, @@ -263,8 +150,6 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() content: d.content, })); - // Update lastUserMessage with search results - const messageUpdateStart = performance.now(); if (lastUserMessage) { lastUserMessage.content = typeof lastUserMessage.content === "string" @@ -277,26 +162,30 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() text: `<context>${JSON.stringify(cleanDocumentsForContext)}</context>`, }, ]; - } - - // edit the last coreusermessage in the array - if (lastUserMessage) { coreMessages[coreMessages.length - 1] = lastUserMessage; } - const messageUpdateEnd = performance.now(); - console.log( - `Message update took ${messageUpdateEnd - messageUpdateStart}ms` - ); try { - const streamStart = performance.now(); + const data = new StreamData(); + data.appendMessageAnnotation( + finalResults.map((r) => ({ + id: r.id, + content: r.content, + type: r.type, + url: r.url, + title: r.title, + description: r.description, + ogImage: r.ogImage, + userId: r.userId, + createdAt: r.createdAt.toISOString(), + updatedAt: r.updatedAt?.toISOString() || null, + })) + ); + const result = await streamText({ model: googleClient, experimental_providerMetadata: { - metadata: { - userId: user.id, - chatThreadId: threadUuid, - }, + metadata: { userId: user.id, chatThreadId: threadUuid ?? "" }, }, experimental_transform: smoothStream(), messages: [ @@ -323,12 +212,11 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() ], async onFinish(completion) { try { - // remove context from lastUserMessage if (lastUserMessage) { lastUserMessage.content = typeof lastUserMessage.content === "string" ? lastUserMessage.content.replace( - /<context>([\s\S]*?)<\/context>/g, + /<context>[\s\S]*?<\/context>/g, "" ) : lastUserMessage.content.filter( @@ -338,60 +226,34 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() part.text.startsWith("<context>") ) ); - coreMessages[coreMessages.length - 1] = lastUserMessage; } - console.log("results", results); - const newMessages = [ ...coreMessages, { role: "assistant", content: completion.text + - `<context>[${JSON.stringify(results)}]</context>`, + `<context>[${JSON.stringify(finalResults)}]</context>`, }, ]; - await data.close(); if (threadUuid) { - await database(c.env.HYPERDRIVE.connectionString) + await db .update(chatThreads) .set({ messages: newMessages }) .where(eq(chatThreads.uuid, threadUuid)); } } catch (error) { console.error("Failed to update thread:", error); - // Continue execution - the message was delivered even if saving failed } }, }); - const streamEnd = performance.now(); - console.log(`Stream response took ${streamEnd - streamStart}ms`); - - const data = new StreamData(); - - const context = results.map((r) => ({ - similarity: r.similarity, - id: r.id, - content: r.content, - type: r.type, - url: r.url, - title: r.title, - description: r.description, - ogImage: r.ogImage, - userId: r.userId, - createdAt: r.createdAt.toISOString(), - updatedAt: r.updatedAt?.toISOString() || null, - })); - // Full context objects in the data - data.appendMessageAnnotation(context); - return result.toDataStreamResponse({ headers: { - "Supermemory-Thread-Uuid": threadUuid, + "Supermemory-Thread-Uuid": threadUuid ?? "", "Content-Type": "text/x-unknown", "content-encoding": "identity", "transfer-encoding": "chunked", @@ -408,7 +270,6 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>() ); } - // Handle database connection errors if ((error as AISDKError).cause === "ECONNREFUSED") { return c.json({ error: "Database connection failed" }, 503); } diff --git a/apps/backend/src/utils/fetchers.ts b/apps/backend/src/utils/fetchers.ts index fbca3d4b..2329f48a 100644 --- a/apps/backend/src/utils/fetchers.ts +++ b/apps/backend/src/utils/fetchers.ts @@ -35,11 +35,6 @@ export const fetchContent = async ( tweetUrl.search = ""; // Remove all search params const tweetId = tweetUrl.pathname.split("/").pop(); - const unrolledTweetContent = await step.do( - "get unrolled tweet content", - async () => await unrollTweets(tweetUrl.toString()) - ); - const rawBaseTweetContent = await step.do( "extract tweet content", async () => { @@ -72,8 +67,10 @@ export const fetchContent = async ( }; raw: string; }; - - if (!unrolledTweetContent || isErr(unrolledTweetContent)) { + const unrolledTweetContent = { + value: [rawBaseTweetContent], + }; + if (true) { console.error("Can't get thread, reverting back to single tweet"); tweetContent = { text: rawBaseTweetContent.text, |