diff options
| field | value | date |
|---|---|---|
| author | Dhravya <[email protected]> | 2024-03-31 22:44:22 -0700 |
| committer | Dhravya <[email protected]> | 2024-03-31 22:44:22 -0700 |
| commit | c40202fd04cb2541d88d2a25f2fc5dfc2128508c (patch) | |
| tree | 1fa82723b3d1d25b694cfa9e64f67308e815fb69 | |
| parent | Merge branch 'main' of https://github.com/Dhravya/supermemory into new-ui (diff) | |
| download | supermemory-c40202fd04cb2541d88d2a25f2fc5dfc2128508c.tar.xz supermemory-c40202fd04cb2541d88d2a25f2fc5dfc2128508c.zip | |
added gemini streaming in cf-ai-backend
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | apps/cf-ai-backend/src/index.ts | 77 |
| -rw-r--r-- | apps/web/src/components/QueryAI.tsx | 5 |
| -rw-r--r-- | package.json | 1 |
3 files changed, 53 insertions, 30 deletions
diff --git a/apps/cf-ai-backend/src/index.ts b/apps/cf-ai-backend/src/index.ts index fc7241a0..5f45eeb0 100644 --- a/apps/cf-ai-backend/src/index.ts +++ b/apps/cf-ai-backend/src/index.ts @@ -7,15 +7,15 @@ import type { import { CloudflareVectorizeStore, } from "@langchain/cloudflare"; -import { Ai } from '@cloudflare/ai'; import { OpenAIEmbeddings } from "./OpenAIEmbedder"; -import { AiTextGenerationOutput } from "@cloudflare/ai/dist/ai/tasks/text-generation"; +import { GoogleGenerativeAI } from "@google/generative-ai"; export interface Env { VECTORIZE_INDEX: VectorizeIndex; AI: Fetcher; SECURITY_KEY: string; OPENAI_API_KEY: string; + GOOGLE_AI_API_KEY: string; } @@ -38,7 +38,10 @@ export default { const store = new CloudflareVectorizeStore(embeddings, { index: env.VECTORIZE_INDEX, }); - const ai = new Ai(env.AI) + // const ai = new Ai(env.AI) + + const genAI = new GoogleGenerativeAI(env.GOOGLE_AI_API_KEY); + const model = genAI.getGenerativeModel({ model: "gemini-pro" }); if (pathname === "/add" && request.method === "POST") { @@ -119,22 +122,27 @@ export default { return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 }); } - const metadatas = vec.map(({ metadata }) => metadata) - - console.log(metadatas) - - // TODO: TAKE ALL THE HIGH SCORED IDS INTO CONSIDERATION - const output: AiTextGenerationOutput = await ai.run('@hf/thebloke/mistral-7b-instruct-v0.1-awq', { - prompt: `You are an agent that summarizes a page based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${vec[0].metadata!.text} \nAnswer this question based on the context. 
Question: ${query}\nAnswer:`, - stream: true - }) as ReadableStream - - - return new Response(output, { - headers: { - "content-type": "text/event-stream", - }, - }); + const preparedContext = vec.slice(0, 3).map(({ metadata }) => `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`).join("\n\n"); + + const prompt = `You are an agent that summarizes a page based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${preparedContext} \nAnswer this question based on the context. Question: ${query}\nAnswer:` + const output = await model.generateContentStream(prompt); + + const response = new Response( + new ReadableStream({ + async start(controller) { + const converter = new TextEncoder(); + for await (const chunk of output.stream) { + const chunkText = await chunk.text(); + const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n"); + controller.enqueue(encodedChunk); + } + const doneChunk = converter.encode("data: [DONE]"); + controller.enqueue(doneChunk); + controller.close(); + } + }) + ); + return response; } else if (pathname === "/ask" && request.method === "POST") { @@ -146,17 +154,26 @@ export default { return new Response(JSON.stringify({ message: "Invalid Page Content" }), { status: 400 }); } - const output: AiTextGenerationOutput = await ai.run('@hf/thebloke/mistral-7b-instruct-v0.1-awq', { - prompt: `You are an agent that answers a question based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${body.query} \nAnswer this question based on the context. Question: ${body.query}\nAnswer:`, - stream: true - }) as ReadableStream - - - return new Response(output, { - headers: { - "content-type": "text/event-stream", - }, - }); + const prompt = `You are an agent that answers a question based on the query. 
Be direct and concise, don't say 'based on the context'.\n\n Context:\n${body.query} \nAnswer this question based on the context. Question: ${body.query}\nAnswer:` + const output = await model.generateContentStream(prompt); + + const response = new Response( + new ReadableStream({ + async start(controller) { + const converter = new TextEncoder(); + for await (const chunk of output.stream) { + const chunkText = await chunk.text(); + console.log(chunkText); + const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n"); + controller.enqueue(encodedChunk); + } + const doneChunk = converter.encode("data: [DONE] \n\n"); + controller.enqueue(doneChunk); + controller.close(); + } + }) + ); + return response; } return new Response(JSON.stringify({ message: "Invalid Request" }), { status: 400 }); diff --git a/apps/web/src/components/QueryAI.tsx b/apps/web/src/components/QueryAI.tsx index 811dd899..3cb14178 100644 --- a/apps/web/src/components/QueryAI.tsx +++ b/apps/web/src/components/QueryAI.tsx @@ -82,6 +82,11 @@ function QueryAI() { const response = await fetch(`/api/query?q=${input}`); + if (response.status !== 200) { + setIsAiLoading(false); + return; + } + if (response.body) { let reader = response.body.getReader(); let decoder = new TextDecoder('utf-8'); diff --git a/package.json b/package.json index 5ba4eea6..145a2384 100644 --- a/package.json +++ b/package.json @@ -39,6 +39,7 @@ "@cloudflare/ai": "^1.0.52", "@cloudflare/next-on-pages-next-dev": "^0.0.1", "@crxjs/vite-plugin": "^1.0.14", + "@google/generative-ai": "^0.3.1", "@heroicons/react": "^2.1.1", "@langchain/cloudflare": "^0.0.3", "@radix-ui/colors": "^3.0.0", |