diff options
| author | Dhravya <[email protected]> | 2024-06-23 17:54:05 -0500 |
|---|---|---|
| committer | Dhravya <[email protected]> | 2024-06-23 17:54:05 -0500 |
| commit | 9f751ba89cc37516daa261bbae8723d5566a6602 (patch) | |
| tree | 3e7533aefbf7931c966c930349263ae248cc1135 | |
| parent | added indexes and stuff (diff) | |
| download | supermemory-9f751ba89cc37516daa261bbae8723d5566a6602.tar.xz supermemory-9f751ba89cc37516daa261bbae8723d5566a6602.zip | |
feat: vector lookup and chat is twice as fast now
| -rw-r--r-- | apps/cf-ai-backend/src/index.ts | 184 | ||||
| -rw-r--r-- | apps/cf-ai-backend/src/types.ts | 2 | ||||
| -rw-r--r-- | apps/web/app/(dash)/chat/chatWindow.tsx | 17 | ||||
| -rw-r--r-- | apps/web/app/api/chat/route.ts | 24 | ||||
| -rw-r--r-- | packages/shared-types/index.ts | 8 |
5 files changed, 130 insertions, 105 deletions
diff --git a/apps/cf-ai-backend/src/index.ts b/apps/cf-ai-backend/src/index.ts index e89d170c..224f2a42 100644 --- a/apps/cf-ai-backend/src/index.ts +++ b/apps/cf-ai-backend/src/index.ts @@ -1,6 +1,6 @@ import { z } from "zod"; import { Hono } from "hono"; -import { CoreMessage, generateText, streamText } from "ai"; +import { CoreMessage, streamText } from "ai"; import { chatObj, Env, vectorObj } from "./types"; import { batchCreateChunksAndEmbeddings, @@ -193,90 +193,102 @@ app.post( const body = c.req.valid("json"); const sourcesOnly = query.sourcesOnly === "true"; + + // Return early for dumb requests + if (sourcesOnly && body.sources) { + return c.json(body.sources); + } + const spaces = query.spaces?.split(",") ?? [undefined]; // Get the AI model maker and vector store const { model, store } = await initQuery(c, query.model); - const filter: VectorizeVectorMetadataFilter = { - [`user-${query.user}`]: 1, - }; - console.log("Spaces", spaces); - - // Converting the query to a vector so that we can search for similar vectors - const queryAsVector = await store.embeddings.embedQuery(query.query); - const responses: VectorizeMatches = { matches: [], count: 0 }; - - console.log("hello world", spaces); - - // SLICED to 5 to avoid too many queries - for (const space of spaces.slice(0, 5)) { - console.log("space", space); - if (!space && spaces.length > 1) { - // it's possible for space list to be [undefined] so we only add space filter conditionally - filter[`space-${query.user}-${space}`] = 1; + if (!body.sources) { + const filter: VectorizeVectorMetadataFilter = { + [`user-${query.user}`]: 1, + }; + console.log("Spaces", spaces); + + // Converting the query to a vector so that we can search for similar vectors + const queryAsVector = await store.embeddings.embedQuery(query.query); + const responses: VectorizeMatches = { matches: [], count: 0 }; + + console.log("hello world", spaces); + + // SLICED to 5 to avoid too many queries + for (const space of spaces.slice(0, 5)) { + console.log("space", space); + if (!space && spaces.length > 1) { + // it's possible for space list to be [undefined] so we only add space filter conditionally + filter[`space-${query.user}-${space}`] = 1; + } + + // Because there's no OR operator in the filter, we have to make multiple queries + const resp = await c.env.VECTORIZE_INDEX.query(queryAsVector, { + topK: query.topK, + filter, + returnMetadata: true, + }); + + // Basically recreating the response object + if (resp.count > 0) { + responses.matches.push(...resp.matches); + responses.count += resp.count; + } } - // Because there's no OR operator in the filter, we have to make multiple queries - const resp = await c.env.VECTORIZE_INDEX.query(queryAsVector, { - topK: query.topK, - filter, - returnMetadata: true, - }); - - // Basically recreating the response object - if (resp.count > 0) { - responses.matches.push(...resp.matches); - responses.count += resp.count; - } - } + const minScore = Math.min(...responses.matches.map(({ score }) => score)); + const maxScore = Math.max(...responses.matches.map(({ score }) => score)); + + // We are "normalising" the scores - if all of them are on top, we want to make sure that + // we have a way to filter out the noise. + const normalizedData = responses.matches.map((data) => ({ + ...data, + normalizedScore: + maxScore !== minScore + ? 1 + ((data.score - minScore) / (maxScore - minScore)) * 98 + : 50, // If all scores are the same, set them to the middle of the scale + })); + + let highScoreData = normalizedData.filter( + ({ normalizedScore }) => normalizedScore > 50, + ); - const minScore = Math.min(...responses.matches.map(({ score }) => score)); - const maxScore = Math.max(...responses.matches.map(({ score }) => score)); - - // We are "normalising" the scores - if all of them are on top, we want to make sure that - // we have a way to filter out the noise. - const normalizedData = responses.matches.map((data) => ({ - ...data, - normalizedScore: - maxScore !== minScore - ? 1 + ((data.score - minScore) / (maxScore - minScore)) * 98 - : 50, // If all scores are the same, set them to the middle of the scale - })); - - let highScoreData = normalizedData.filter( - ({ normalizedScore }) => normalizedScore > 50, - ); + // If the normalsation is not done properly, we have a fallback to just get the + // top 3 scores + if (highScoreData.length === 0) { + highScoreData = normalizedData + .sort((a, b) => b.score - a.score) + .slice(0, 3); + } - // If the normalsation is not done properly, we have a fallback to just get the - // top 3 scores - if (highScoreData.length === 0) { - highScoreData = normalizedData - .sort((a, b) => b.score - a.score) - .slice(0, 3); - } + const sortedHighScoreData = highScoreData.sort( + (a, b) => b.normalizedScore - a.normalizedScore, + ); - const sortedHighScoreData = highScoreData.sort( - (a, b) => b.normalizedScore - a.normalizedScore, - ); + body.sources = { + normalizedData, + }; - // So this is kinda hacky, but the frontend needs to do 2 calls to get sources and chat. - // I think this is fine for now, but we can improve this later. - if (sourcesOnly) { - const idsAsStrings = sortedHighScoreData.map((dataPoint) => - dataPoint.id.toString(), - ); + // So this is kinda hacky, but the frontend needs to do 2 calls to get sources and chat. + // I think this is fine for now, but we can improve this later. + if (sourcesOnly) { + const idsAsStrings = sortedHighScoreData.map((dataPoint) => + dataPoint.id.toString(), + ); - const storedContent = await Promise.all( - idsAsStrings.map(async (id) => await c.env.KV.get(id)), - ); + const storedContent = await Promise.all( + idsAsStrings.map(async (id) => await c.env.KV.get(id)), + ); - const metadata = normalizedData.map((datapoint) => datapoint.metadata); + const metadata = normalizedData.map((datapoint) => datapoint.metadata); - return c.json({ ids: storedContent, metadata }); + return c.json({ ids: storedContent, metadata, normalizedData }); + } } - const preparedContext = normalizedData.map( + const preparedContext = body.sources.normalizedData.map( ({ metadata, score, normalizedScore }) => ({ context: `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`, score, @@ -330,20 +342,28 @@ app.delete( }, ); -// ERROR #1 - this is the api that the editor uses, it is just a scrape off of /api/chat so you may check that out -app.get('/api/editorai', zValidator( - "query", - z.object({ - context: z.string(), - request: z.string(), - }), -), async (c)=> { - const { context, request } = c.req.valid("query"); - const { model } = await initQuery(c); +// ERROR #1 - this is the api that the editor uses, it is just a scrape off of /api/chat so you may check that out +app.get( + "/api/editorai", + zValidator( + "query", + z.object({ + context: z.string(), + request: z.string(), + }), + ), + async (c) => { + const { context, request } = c.req.valid("query"); + const { model } = await initQuery(c); - const response = await streamText({ model, prompt: `${request}-${context}`, maxTokens: 224 }); + const response = await streamText({ + model, + prompt: `${request}-${context}`, + maxTokens: 224, + }); - return response.toTextStreamResponse(); -}) + return response.toTextStreamResponse(); + }, +); export default app; diff --git a/apps/cf-ai-backend/src/types.ts b/apps/cf-ai-backend/src/types.ts index 417d6320..dc97777c 100644 --- a/apps/cf-ai-backend/src/types.ts +++ b/apps/cf-ai-backend/src/types.ts @@ -1,3 +1,4 @@ +import { sourcesZod } from "@repo/shared-types"; import { z } from "zod"; export type Env = { @@ -37,6 +38,7 @@ export const contentObj = z.object({ export const chatObj = z.object({ chatHistory: z.array(contentObj).optional(), + sources: sourcesZod.optional(), }); export const vectorObj = z.object({ diff --git a/apps/web/app/(dash)/chat/chatWindow.tsx b/apps/web/app/(dash)/chat/chatWindow.tsx index 8485d0b2..9a18cfe7 100644 --- a/apps/web/app/(dash)/chat/chatWindow.tsx +++ b/apps/web/app/(dash)/chat/chatWindow.tsx @@ -6,7 +6,7 @@ import QueryInput from "../home/queryinput"; import { cn } from "@repo/ui/lib/utils"; import { motion } from "framer-motion"; import { useRouter } from "next/navigation"; -import { ChatHistory } from "@repo/shared-types"; +import { ChatHistory, sourcesZod } from "@repo/shared-types"; import { Accordion, AccordionContent, @@ -20,15 +20,10 @@ import rehypeKatex from "rehype-katex"; import rehypeHighlight from "rehype-highlight"; import { code, p } from "./markdownRenderHelpers"; import { codeLanguageSubset } from "@/lib/constants"; -import { z } from "zod"; import { toast } from "sonner"; import Link from "next/link"; import { createChatObject } from "@/app/actions/doers"; -import { - ClipboardIcon, - ShareIcon, - SpeakerWaveIcon, -} from "@heroicons/react/24/outline"; +import { ClipboardIcon } from "@heroicons/react/24/outline"; import { SendIcon } from "lucide-react"; function ChatWindow({ @@ -83,11 +78,6 @@ function ChatWindow({ // TODO: handle this properly const sources = await sourcesFetch.json(); - const sourcesZod = z.object({ - ids: z.array(z.string()), - metadata: z.array(z.any()), - }); - const sourcesParsed = sourcesZod.safeParse(sources); if (!sourcesParsed.success) { @@ -100,7 +90,6 @@ function ChatWindow({ behavior: "smooth", }); - // Assuming this is part of a larger function within a React component const updateChatHistoryAndFetch = async () => { // Step 1: Update chat history with the assistant's response await new Promise((resolve) => { @@ -143,7 +132,7 @@ function ChatWindow({ `/api/chat?q=${query}&spaces=${spaces}&threadId=${threadId}`, { method: "POST", - body: JSON.stringify({ chatHistory }), + body: JSON.stringify({ chatHistory, sources: sourcesParsed.data }), }, ); diff --git a/apps/web/app/api/chat/route.ts b/apps/web/app/api/chat/route.ts index d0e53066..d1730baa 100644 --- a/apps/web/app/api/chat/route.ts +++ b/apps/web/app/api/chat/route.ts @@ -1,5 +1,10 @@ import { type NextRequest } from "next/server"; -import { ChatHistoryZod, convertChatHistoryList } from "@repo/shared-types"; +import { + ChatHistory, + ChatHistoryZod, + convertChatHistoryList, + SourcesFromApi, +} from "@repo/shared-types"; import { ensureAuth } from "../ensureAuth"; import { z } from "zod"; @@ -23,7 +28,11 @@ export async function POST(req: NextRequest) { const sourcesOnly = url.searchParams.get("sourcesOnly") ?? "false"; - const chatHistory = await req.json(); + const jsonRequest = (await req.json()) as { + chatHistory: ChatHistory[]; + sources: SourcesFromApi[] | undefined; + }; + const { chatHistory, sources } = jsonRequest; if (!query || query.trim.length < 0) { return new Response(JSON.stringify({ message: "Invalid query" }), { @@ -31,9 +40,7 @@ export async function POST(req: NextRequest) { }); } - const validated = z - .object({ chatHistory: z.array(ChatHistoryZod) }) - .safeParse(chatHistory ?? []); + const validated = z.array(ChatHistoryZod).safeParse(chatHistory ?? []); if (!validated.success) { return new Response( @@ -45,9 +52,7 @@ export async function POST(req: NextRequest) { ); } - const modelCompatible = await convertChatHistoryList( - validated.data.chatHistory, - ); + const modelCompatible = await convertChatHistoryList(validated.data); const resp = await fetch( `${process.env.BACKEND_BASE_URL}/api/chat?query=${query}&user=${session.user.id}&sourcesOnly=${sourcesOnly}&spaces=${spaces}`, @@ -59,12 +64,13 @@ export async function POST(req: NextRequest) { method: "POST", body: JSON.stringify({ chatHistory: modelCompatible, + sources, }), }, ); if (sourcesOnly == "true") { - const data = await resp.json(); + const data = (await resp.json()) as SourcesFromApi; return new Response(JSON.stringify(data), { status: 200 }); } diff --git a/packages/shared-types/index.ts b/packages/shared-types/index.ts index 318684b7..9d5caa40 100644 --- a/packages/shared-types/index.ts +++ b/packages/shared-types/index.ts @@ -59,3 +59,11 @@ export function convertChatHistoryList( return convertedChats; } + +export const sourcesZod = z.object({ + ids: z.array(z.string()), + metadata: z.array(z.any()), + normalizedData: z.array(z.any()).optional(), +}); + +export type SourcesFromApi = z.infer<typeof sourcesZod>; |