aboutsummaryrefslogtreecommitdiff
path: root/apps/backend/src
diff options
context:
space:
mode:
authorDhravya Shah <[email protected]>2025-01-27 21:02:43 -0700
committerDhravya Shah <[email protected]>2025-01-27 21:02:43 -0700
commitd2bf8d6623368f9ee1abb7d9add237fe6d003f1c (patch)
tree8f98c0793c62c177a45e6ad156b2bad872d6e474 /apps/backend/src
parentchange embedding model (diff)
downloadsupermemory-d2bf8d6623368f9ee1abb7d9add237fe6d003f1c.tar.xz
supermemory-d2bf8d6623368f9ee1abb7d9add237fe6d003f1c.zip
change embedding model
Diffstat (limited to 'apps/backend/src')
-rw-r--r--apps/backend/src/routes/actions.ts249
-rw-r--r--apps/backend/src/utils/fetchers.ts11
2 files changed, 59 insertions, 201 deletions
diff --git a/apps/backend/src/routes/actions.ts b/apps/backend/src/routes/actions.ts
index b9e6a4c2..1ac02fdd 100644
--- a/apps/backend/src/routes/actions.ts
+++ b/apps/backend/src/routes/actions.ts
@@ -54,8 +54,6 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
const { messages, threadId } = await c.req.valid("json");
- // TODO: add rate limiting
-
const unfilteredCoreMessages = convertToCoreMessages(
(messages as Message[])
.filter((m) => m.content.length > 0)
@@ -67,125 +65,66 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
? `<context>${JSON.stringify(m.annotations)}</context>`
: ""),
experimental_attachments:
- m.experimental_attachments &&
- m.experimental_attachments.length > 0
+ m.experimental_attachments?.length &&
+ m.experimental_attachments?.length > 0
? m.experimental_attachments
: (m.data as { files: [] })?.files,
}))
);
- // make sure that there is no empty messages. if there is, remove it.
const coreMessages = unfilteredCoreMessages.filter(
(message) => message.content.length > 0
);
- // .map(async (c) => {
- // if (
- // Array.isArray(c.content) &&
- // c.content.some((c) => c.type !== "text")
- // ) {
- // // convert attachments (IMAGE and files) to base64 by fetching them
- // const attachments = c.content.filter((c) => c.type !== "text");
- // const base64Attachments = await Promise.all(
- // attachments.map(async (a) => {
- // const type = (a as ImagePart | FilePart).type;
- // if (type === "image") {
- // const response = await fetch((a as ImagePart).image.toString());
- // return response.arrayBuffer();
- // } else if (type === "file") {
- // const response = await fetch((a as FilePart).data.toString());
- // return response.arrayBuffer();
- // }
- // })
- // );
- // }
- // });
-
- console.log("Core messages", JSON.stringify(coreMessages, null, 2));
-
- let threadUuid = threadId;
+ const db = database(c.env.HYPERDRIVE.connectionString);
const { initLogger, wrapAISDKModel } = await import("braintrust");
+ // Initialize clients and loggers
const logger = initLogger({
projectName: "supermemory",
apiKey: c.env.BRAINTRUST_API_KEY,
});
- // const gemini = createOpenAI({
- // apiKey: c.env.GEMINI_API_KEY,
- // baseURL: "https://generativelanguage.googleapis.com/v1beta/openai/",
- // });
- const openaiClient = openai(c.env);
-
- const googleClient = wrapAISDKModel(
- // google(c.env.GEMINI_API_KEY).chat("gemini-exp-1206")
- openai(c.env).chat("gpt-4o")
- );
-
- // Create new thread if none exists
- if (!threadUuid) {
- const uuid = randomId();
- const newThread = await database(c.env.HYPERDRIVE.connectionString)
- .insert(chatThreads)
- .values({
- firstMessage: messages[0].content,
- userId: user.id,
- uuid: uuid,
- messages: coreMessages,
- })
- .returning();
-
- threadUuid = newThread[0].uuid;
- }
- const openAi = openai(c.env);
-
-
-
- let lastUserMessage = coreMessages
- .reverse()
- .find((i) => i.role === "user");
+ const googleClient = wrapAISDKModel(openai(c.env).chat("gpt-4o"));
- // get the text of lastUserMEssage
+ // Get last user message and generate embedding in parallel with thread creation
+ let lastUserMessage = coreMessages.findLast((i) => i.role === "user");
const queryText =
typeof lastUserMessage?.content === "string"
? lastUserMessage.content
: lastUserMessage?.content.map((c) => (c as TextPart).text).join("");
- console.log("querytext", queryText);
-
- if (!queryText ||queryText.length === 0) {
- return c.json({ error: "Failed to generate embedding for query" }, 500);
+ if (!queryText || queryText.length === 0) {
+ return c.json({ error: "Empty query" }, 400);
}
- const embedStart = performance.now();
- const { data: embedding } = await c.env.AI.run("@cf/baai/bge-base-en-v1.5", {
- text: queryText,
- });
- const embedEnd = performance.now();
- console.log(`Embedding generation took ${embedEnd - embedStart}ms`);
+ // Run embedding generation and thread creation in parallel
+ const [{ data: embedding }, thread] = await Promise.all([
+ c.env.AI.run("@cf/baai/bge-base-en-v1.5", { text: queryText }),
+ !threadId
+ ? db
+ .insert(chatThreads)
+ .values({
+ firstMessage: messages[0].content,
+ userId: user.id,
+ uuid: randomId(),
+ messages: coreMessages,
+ })
+ .returning()
+ : null,
+ ]);
+
+ const threadUuid = threadId || thread?.[0].uuid;
if (!embedding) {
- return c.json({ error: "Failed to generate embedding for query" }, 500);
+ return c.json({ error: "Failed to generate embedding" }, 500);
}
- // Perform semantic search using cosine similarity
- // Log the query text to debug what we're searching for
- console.log("Searching for:", queryText);
- console.log("user id", user.id);
-
- const similarity = sql<number>`1 - (${cosineDistance(
- chunk.embeddings,
- embedding[0]
- )})`;
-
- // Find similar chunks using cosine similarity
- // Join with documents table to get chunks only from documents the user has access to
- // First get all results to normalize
- // Get top 20 results first to avoid processing entire dataset
- const dbQueryStart = performance.now();
- const topResults = await database(c.env.HYPERDRIVE.connectionString)
+ // Perform semantic search
+ const similarity = sql<number>`1 - (${cosineDistance(chunk.embeddings, embedding[0])})`;
+
+ const finalResults = await db
.select({
- similarity,
id: documents.id,
content: documents.content,
type: documents.type,
@@ -200,62 +139,10 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
.from(chunk)
.innerJoin(documents, eq(chunk.documentId, documents.id))
.where(and(eq(documents.userId, user.id), sql`${similarity} > 0.4`))
- .orderBy(desc(similarity));
-
- // Get unique documents with their highest similarity chunks
- const uniqueDocuments = Object.values(
- topResults.reduce(
- (acc, curr) => {
- if (
- !acc[curr.id] ||
- acc[curr.id].content === curr.content ||
- acc[curr.id].url === curr.url
- ) {
- acc[curr.id] = curr;
- }
- return acc;
- },
- {} as Record<number, (typeof topResults)[0]>
- )
- ).slice(0, 5);
-
- const dbQueryEnd = performance.now();
- console.log(`Database query took ${dbQueryEnd - dbQueryStart}ms`);
-
- // Calculate min/max once for the subset
- const processingStart = performance.now();
- const minSimilarity = Math.min(
- ...uniqueDocuments.map((r) => r.similarity)
- );
- const maxSimilarity = Math.max(
- ...uniqueDocuments.map((r) => r.similarity)
- );
- const range = maxSimilarity - minSimilarity;
-
- // Normalize the results
- const normalizedResults = uniqueDocuments.map((result) => ({
- ...result,
- normalizedSimilarity:
- range === 0 ? 1 : (result.similarity - minSimilarity) / range,
- }));
-
- // Get either all results above 0.6 threshold, or at least top 3 results
- const results = normalizedResults
- .sort((a, b) => b.normalizedSimilarity - a.normalizedSimilarity)
- .slice(
- 0,
- Math.max(
- 3,
- normalizedResults.filter((r) => r.normalizedSimilarity > 0.6).length
- )
- );
+ .orderBy(desc(similarity))
+ .limit(5);
- const processingEnd = performance.now();
- console.log(
- `Results processing took ${processingEnd - processingStart}ms`
- );
-
- const cleanDocumentsForContext = results.map((d) => ({
+ const cleanDocumentsForContext = finalResults.map((d) => ({
title: d.title,
description: d.description,
url: d.url,
@@ -263,8 +150,6 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
content: d.content,
}));
- // Update lastUserMessage with search results
- const messageUpdateStart = performance.now();
if (lastUserMessage) {
lastUserMessage.content =
typeof lastUserMessage.content === "string"
@@ -277,26 +162,30 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
text: `<context>${JSON.stringify(cleanDocumentsForContext)}</context>`,
},
];
- }
-
- // edit the last coreusermessage in the array
- if (lastUserMessage) {
coreMessages[coreMessages.length - 1] = lastUserMessage;
}
- const messageUpdateEnd = performance.now();
- console.log(
- `Message update took ${messageUpdateEnd - messageUpdateStart}ms`
- );
try {
- const streamStart = performance.now();
+ const data = new StreamData();
+ data.appendMessageAnnotation(
+ finalResults.map((r) => ({
+ id: r.id,
+ content: r.content,
+ type: r.type,
+ url: r.url,
+ title: r.title,
+ description: r.description,
+ ogImage: r.ogImage,
+ userId: r.userId,
+ createdAt: r.createdAt.toISOString(),
+ updatedAt: r.updatedAt?.toISOString() || null,
+ }))
+ );
+
const result = await streamText({
model: googleClient,
experimental_providerMetadata: {
- metadata: {
- userId: user.id,
- chatThreadId: threadUuid,
- },
+ metadata: { userId: user.id, chatThreadId: threadUuid ?? "" },
},
experimental_transform: smoothStream(),
messages: [
@@ -323,12 +212,11 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
],
async onFinish(completion) {
try {
- // remove context from lastUserMessage
if (lastUserMessage) {
lastUserMessage.content =
typeof lastUserMessage.content === "string"
? lastUserMessage.content.replace(
- /<context>([\s\S]*?)<\/context>/g,
+ /<context>[\s\S]*?<\/context>/g,
""
)
: lastUserMessage.content.filter(
@@ -338,60 +226,34 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
part.text.startsWith("<context>")
)
);
-
coreMessages[coreMessages.length - 1] = lastUserMessage;
}
- console.log("results", results);
-
const newMessages = [
...coreMessages,
{
role: "assistant",
content:
completion.text +
- `<context>[${JSON.stringify(results)}]</context>`,
+ `<context>[${JSON.stringify(finalResults)}]</context>`,
},
];
- await data.close();
if (threadUuid) {
- await database(c.env.HYPERDRIVE.connectionString)
+ await db
.update(chatThreads)
.set({ messages: newMessages })
.where(eq(chatThreads.uuid, threadUuid));
}
} catch (error) {
console.error("Failed to update thread:", error);
- // Continue execution - the message was delivered even if saving failed
}
},
});
- const streamEnd = performance.now();
- console.log(`Stream response took ${streamEnd - streamStart}ms`);
-
- const data = new StreamData();
-
- const context = results.map((r) => ({
- similarity: r.similarity,
- id: r.id,
- content: r.content,
- type: r.type,
- url: r.url,
- title: r.title,
- description: r.description,
- ogImage: r.ogImage,
- userId: r.userId,
- createdAt: r.createdAt.toISOString(),
- updatedAt: r.updatedAt?.toISOString() || null,
- }));
- // Full context objects in the data
- data.appendMessageAnnotation(context);
-
return result.toDataStreamResponse({
headers: {
- "Supermemory-Thread-Uuid": threadUuid,
+ "Supermemory-Thread-Uuid": threadUuid ?? "",
"Content-Type": "text/x-unknown",
"content-encoding": "identity",
"transfer-encoding": "chunked",
@@ -408,7 +270,6 @@ const actions = new Hono<{ Variables: Variables; Bindings: Env }>()
);
}
- // Handle database connection errors
if ((error as AISDKError).cause === "ECONNREFUSED") {
return c.json({ error: "Database connection failed" }, 503);
}
diff --git a/apps/backend/src/utils/fetchers.ts b/apps/backend/src/utils/fetchers.ts
index fbca3d4b..2329f48a 100644
--- a/apps/backend/src/utils/fetchers.ts
+++ b/apps/backend/src/utils/fetchers.ts
@@ -35,11 +35,6 @@ export const fetchContent = async (
tweetUrl.search = ""; // Remove all search params
const tweetId = tweetUrl.pathname.split("/").pop();
- const unrolledTweetContent = await step.do(
- "get unrolled tweet content",
- async () => await unrollTweets(tweetUrl.toString())
- );
-
const rawBaseTweetContent = await step.do(
"extract tweet content",
async () => {
@@ -72,8 +67,10 @@ export const fetchContent = async (
};
raw: string;
};
-
- if (!unrolledTweetContent || isErr(unrolledTweetContent)) {
+ const unrolledTweetContent = {
+ value: [rawBaseTweetContent],
+ };
+ if (true) {
console.error("Can't get thread, reverting back to single tweet");
tweetContent = {
text: rawBaseTweetContent.text,