diff options
| author | Dhravya <[email protected]> | 2024-04-08 18:12:16 -0700 |
|---|---|---|
| committer | Dhravya <[email protected]> | 2024-04-08 18:12:16 -0700 |
| commit | f04fa3faf75c1b2c63f094632c15a528a98932c5 (patch) | |
| tree | f2831ca93153a0b7698cb0517a35c0601f1b5ed3 | |
| parent | made it functional (diff) | |
| download | supermemory-f04fa3faf75c1b2c63f094632c15a528a98932c5.tar.xz supermemory-f04fa3faf75c1b2c63f094632c15a528a98932c5.zip | |
setup for multi chat
| -rw-r--r-- | apps/cf-ai-backend/src/routes.ts | 3 | ||||
| -rw-r--r-- | apps/cf-ai-backend/src/routes/chat.ts | 91 | ||||
| -rw-r--r-- | apps/cf-ai-backend/src/routes/query.ts | 2 | ||||
| -rw-r--r-- | apps/web/src/app/api/chat/route.ts | 62 | ||||
| -rw-r--r-- | apps/web/src/components/ChatMessage.tsx | 50 | ||||
| -rw-r--r-- | apps/web/src/components/Main.tsx | 52 | ||||
| -rw-r--r-- | apps/web/types/memory.tsx | 5 |
7 files changed, 257 insertions, 8 deletions
diff --git a/apps/cf-ai-backend/src/routes.ts b/apps/cf-ai-backend/src/routes.ts index 4a2d2827..841f107a 100644 --- a/apps/cf-ai-backend/src/routes.ts +++ b/apps/cf-ai-backend/src/routes.ts @@ -2,6 +2,7 @@ import { CloudflareVectorizeStore } from '@langchain/cloudflare'; import * as apiAdd from './routes/add'; import * as apiQuery from "./routes/query" import * as apiAsk from "./routes/ask" +import * as apiChat from "./routes/chat" import { OpenAIEmbeddings } from './OpenAIEmbedder'; import { GenerativeModel } from '@google/generative-ai'; import { Request } from '@cloudflare/workers-types'; @@ -17,6 +18,8 @@ routeMap.set('/query', apiQuery); routeMap.set('/ask', apiAsk); +routeMap.set('/chat', apiChat); + // Add more route mappings as needed // routeMap.set('/api/otherRoute', { ... }); diff --git a/apps/cf-ai-backend/src/routes/chat.ts b/apps/cf-ai-backend/src/routes/chat.ts new file mode 100644 index 00000000..1082998f --- /dev/null +++ b/apps/cf-ai-backend/src/routes/chat.ts @@ -0,0 +1,91 @@ +import { Content, GenerativeModel } from "@google/generative-ai"; +import { OpenAIEmbeddings } from "../OpenAIEmbedder"; +import { CloudflareVectorizeStore } from "@langchain/cloudflare"; +import { Request } from "@cloudflare/workers-types"; + +export async function POST(request: Request, _: CloudflareVectorizeStore, embeddings: OpenAIEmbeddings, model: GenerativeModel, env?: Env) { + const queryparams = new URL(request.url).searchParams; + const query = queryparams.get("q"); + const topK = parseInt(queryparams.get("topK") ?? "5"); + const user = queryparams.get("user") + const space = queryparams.get("space") + + const sourcesOnly = (queryparams.get("sourcesOnly") ?? "false") + + if (!user) { + return new Response(JSON.stringify({ message: "Invalid User" }), { status: 400 }); + } + + if (!query) { + return new Response(JSON.stringify({ message: "Invalid Query" }), { status: 400 }); + } + + const filter: VectorizeVectorMetadataFilter = { + user + } + + if (space) { + filter.space + } + + const queryAsVector = await embeddings.embedQuery(query); + + const resp = await env!.VECTORIZE_INDEX.query(queryAsVector, { + topK, + filter + }); + + if (resp.count === 0) { + return new Response(JSON.stringify({ message: "No Results Found" }), { status: 404 }); + } + + const highScoreIds = resp.matches.filter(({ score }) => score > 0.3).map(({ id }) => id) + + if (sourcesOnly === "true") { + return new Response(JSON.stringify({ ids: highScoreIds }), { status: 200 }); + } + + const vec = await env!.VECTORIZE_INDEX.getByIds(highScoreIds) + + const preparedContext = vec.slice(0, 3).map(({ metadata }) => `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`).join("\n\n"); + + const body = await request.json() as { + chatHistory?: Content[] + }; + + const defaultHistory = [ + { + role: "user", + parts: [{ text: `You are an agent that summarizes a page based on the query. don't say 'based on the context'. I expect you to be like a 'Second Brain'. you will be provided with the context (old saved posts) and questions. Answer accordingly. Answer in markdown format` }], + }, + { + role: "model", + parts: [{ text: "Ok, I am a personal assistant, and will act as a second brain to help with user's queries." }], + }, + ] as Content[]; + + const chat = model.startChat({ + history: [...defaultHistory, ...(body.chatHistory ?? [])], + }); + + const prompt = `Context:\n${preparedContext}\n\nQuestion: ${query}\nAnswer:`; + + const output = await chat.sendMessageStream(prompt); + + const response = new Response( + new ReadableStream({ + async start(controller) { + const converter = new TextEncoder(); + for await (const chunk of output.stream) { + const chunkText = await chunk.text(); + const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n"); + controller.enqueue(encodedChunk); + } + const doneChunk = converter.encode("data: [DONE]"); + controller.enqueue(doneChunk); + controller.close(); + } + }) + ); + return response; +} diff --git a/apps/cf-ai-backend/src/routes/query.ts b/apps/cf-ai-backend/src/routes/query.ts index e02d0150..be237d7d 100644 --- a/apps/cf-ai-backend/src/routes/query.ts +++ b/apps/cf-ai-backend/src/routes/query.ts @@ -36,7 +36,7 @@ export async function GET(request: Request, _: CloudflareVectorizeStore, embeddi }); if (resp.count === 0) { - return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 }); + return new Response(JSON.stringify({ message: "No Results Found" }), { status: 404 }); } const highScoreIds = resp.matches.filter(({ score }) => score > 0.3).map(({ id }) => id) diff --git a/apps/web/src/app/api/chat/route.ts b/apps/web/src/app/api/chat/route.ts new file mode 100644 index 00000000..2cb03186 --- /dev/null +++ b/apps/web/src/app/api/chat/route.ts @@ -0,0 +1,62 @@ +import { db } from "@/server/db"; +import { eq } from "drizzle-orm"; +import { sessions, users } from "@/server/db/schema"; +import { type NextRequest, NextResponse } from "next/server"; +import { env } from "@/env"; +import { ChatHistory } from "../../../../types/memory"; + +export const runtime = "edge"; + +export async function POST(req: NextRequest) { + const token = req.cookies.get("next-auth.session-token")?.value ?? req.cookies.get("__Secure-authjs.session-token")?.value ?? req.cookies.get("authjs.session-token")?.value ?? req.headers.get("Authorization")?.replace("Bearer ", ""); + + const sessionData = await db.select().from(sessions).where(eq(sessions.sessionToken, token!)) + + if (!sessionData || sessionData.length === 0) { + return new Response(JSON.stringify({ message: "Invalid Key, session not found." }), { status: 404 }); + } + + const user = await db.select().from(users).where(eq(users.id, sessionData[0].userId)).limit(1) + + if (!user || user.length === 0) { + return NextResponse.json({ message: "Invalid Key, session not found." }, { status: 404 }); + } + + const session = { session: sessionData[0], user: user[0] } + + const query = new URL(req.url).searchParams.get("q"); + const sourcesOnly = new URL(req.url).searchParams.get("sourcesOnly") ?? "false"; + + const chatHistory = await req.json() as { + chatHistory: ChatHistory[] + }; + + + if (!query) { + return new Response(JSON.stringify({ message: "Invalid query" }), { status: 400 }); + } + + const resp = await fetch(`https://cf-ai-backend.dhravya.workers.dev/chat?q=${query}&user=${session.user.email ?? session.user.name}&sourcesOnly=${sourcesOnly}`, { + headers: { + "X-Custom-Auth-Key": env.BACKEND_SECURITY_KEY, + }, + method: "POST", + body: JSON.stringify({ + chatHistory + }) + }) + + console.log(resp.status) + + if (resp.status !== 200 || !resp.ok) { + const errorData = await resp.json(); + console.log(errorData) + return new Response(JSON.stringify({ message: "Error in CF function", error: errorData }), { status: resp.status }); + } + + // Stream the response back to the client + const { readable, writable } = new TransformStream(); + resp && resp.body!.pipeTo(writable); + + return new Response(readable, { status: 200 }); +}
\ No newline at end of file diff --git a/apps/web/src/components/ChatMessage.tsx b/apps/web/src/components/ChatMessage.tsx new file mode 100644 index 00000000..a8199758 --- /dev/null +++ b/apps/web/src/components/ChatMessage.tsx @@ -0,0 +1,50 @@ +import React from 'react'; +import { Avatar, AvatarFallback, AvatarImage } from './ui/avatar'; +import { User } from 'next-auth'; +import { User2 } from 'lucide-react'; +import Image from 'next/image'; + +function ChatMessage({ + message, + user, +}: { + message: string; + user: User | 'ai'; +}) { + return ( + <div className="flex flex-col gap-4"> + <div + className={`font-bold ${!(user === 'ai') && 'text-xl '} flex flex-col md:flex-row items-center gap-4`} + > + <Avatar> + {user === 'ai' ? ( + <Image + src="/logo.png" + width={48} + height={48} + alt="AI" + className="rounded-md w-12 h-12" + /> + ) : user?.image ? ( + <> + <AvatarImage + className="h-6 w-6 rounded-lg" + src={user?.image} + alt="user pfp" + /> + <AvatarFallback> + {user?.name?.split(' ').map((n) => n[0])}{' '} + </AvatarFallback> + </> + ) : ( + <User2 strokeWidth={1.3} className="h-6 w-6" /> + )} + </Avatar> + <div className="ml-4">{message}</div> + </div> + <div className="w-full h-0.5 bg-gray-700 my-2 md:my-0 md:mx-4 mt-8"></div> + </div> + ); +} + +export { ChatMessage }; diff --git a/apps/web/src/components/Main.tsx b/apps/web/src/components/Main.tsx index 86679dcf..a9111494 100644 --- a/apps/web/src/components/Main.tsx +++ b/apps/web/src/components/Main.tsx @@ -1,5 +1,5 @@ 'use client'; -import { useEffect, useRef, useState } from 'react'; +import { useCallback, useEffect, useRef, useState } from 'react'; import { FilterCombobox } from './Sidebar/FilterCombobox'; import { Textarea2 } from './ui/textarea'; import { ArrowRight } from 'lucide-react'; @@ -8,6 +8,9 @@ import useViewport from '@/hooks/useViewport'; import { motion } from 'framer-motion'; import { cn } from '@/lib/utils'; import SearchResults from './SearchResults'; +import { ChatHistory } from '../../types/memory'; +import { ChatMessage } from './ChatMessage'; +import { useSession } from 'next-auth/react'; function supportsDVH() { try { @@ -24,7 +27,34 @@ export default function Main({ sidebarOpen }: { sidebarOpen: boolean }) { const [searchResults, setSearchResults] = useState<string[]>([]); const [isAiLoading, setIsAiLoading] = useState(false); + const { data: session } = useSession(); + + // Variable to keep track of the chat history in this session + const [chatHistory, setChatHistory] = useState<ChatHistory[]>([]); + + // TEMPORARY solution: Basically this is to just keep track of the sources used for each chat message + // Not a great solution + const [chatTextSourceDict, setChatTextSourceDict] = useState< + Record<string, string> + >({}); + + // helper function to append a new msg + const appendToChatHistory = useCallback( + (role: 'user' | 'model', content: string) => { + setChatHistory((prev) => [ + ...prev, + { + role, + parts: [{ text: content }], + }, + ]); + }, + [], + ); + + // This is the streamed AI response we get from the server. const [aiResponse, setAIResponse] = useState(''); + const [toBeParsed, setToBeParsed] = useState(''); const textArea = useRef<HTMLTextAreaElement>(null); @@ -153,17 +183,28 @@ export default function Main({ sidebarOpen }: { sidebarOpen: boolean }) { hide ? '' : 'main-hidden', )} > + <div className="flex flex-col w-full"> + {chatHistory.map((chat, index) => ( + <ChatMessage + key={index} + message={chat.parts[0].text} + user={chat.role === 'model' ? 'ai' : session?.user!} + /> + ))} + </div> <h1 className="text-rgray-11 mt-auto w-full text-center text-3xl md:mt-0"> Ask your Second brain </h1> - <form onSubmit={async (e) => await getSearchResults(e)}> + <form + className="mt-auto h-max min-h-[3em] w-full resize-y flex-row items-start justify-center overflow-none py-5 md:mt-0 md:h-[20vh] md:resize-none md:flex-col md:items-center md:justify-center md:p-2 md:pb-2 md:pt-2" + onSubmit={async (e) => await getSearchResults(e)} + > <Textarea2 ref={textArea} - className="mt-auto h-max max-h-[30em] min-h-[3em] resize-y flex-row items-start justify-center overflow-auto py-5 md:mt-0 md:h-[20vh] md:resize-none md:flex-col md:items-center md:justify-center md:p-2 md:pb-2 md:pt-2" textAreaProps={{ placeholder: 'Ask your SuperMemory...', className: - 'h-auto overflow-auto md:h-full md:resize-none text-lg py-0 px-2 md:py-0 md:p-5 resize-y text-rgray-11 w-full min-h-[1em]', + 'h-auto overflow-auto md:h-full md:resize-none text-lg py-0 px-2 pt-2 md:py-0 md:p-5 resize-y text-rgray-11 w-full min-h-[1em]', value, autoFocus: true, onChange: (e) => setValue(e.target.value), @@ -181,9 +222,6 @@ export default function Main({ sidebarOpen }: { sidebarOpen: boolean }) { </div> </Textarea2> </form> - {searchResults && ( - <SearchResults aiResponse={aiResponse} sources={searchResults} /> - )} {width <= 768 && <MemoryDrawer hide={hide} />} </motion.main> ); diff --git a/apps/web/types/memory.tsx b/apps/web/types/memory.tsx index f184615a..e71e92c9 100644 --- a/apps/web/types/memory.tsx +++ b/apps/web/types/memory.tsx @@ -53,3 +53,8 @@ export type CollectedSpaces = { title: string; content: StoredContent[]; }; + +export type ChatHistory = { + role: 'user' | 'model'; + parts: [{ text: string }]; +}; |