aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDhravya <[email protected]>2024-03-31 22:44:22 -0700
committerDhravya <[email protected]>2024-03-31 22:44:22 -0700
commitc40202fd04cb2541d88d2a25f2fc5dfc2128508c (patch)
tree1fa82723b3d1d25b694cfa9e64f67308e815fb69
parentMerge branch 'main' of https://github.com/Dhravya/supermemory into new-ui (diff)
downloadsupermemory-c40202fd04cb2541d88d2a25f2fc5dfc2128508c.tar.xz
supermemory-c40202fd04cb2541d88d2a25f2fc5dfc2128508c.zip
added gemini streaming in cf-ai-backend
-rw-r--r--apps/cf-ai-backend/src/index.ts77
-rw-r--r--apps/web/src/components/QueryAI.tsx5
-rw-r--r--package.json1
3 files changed, 53 insertions, 30 deletions
diff --git a/apps/cf-ai-backend/src/index.ts b/apps/cf-ai-backend/src/index.ts
index fc7241a0..5f45eeb0 100644
--- a/apps/cf-ai-backend/src/index.ts
+++ b/apps/cf-ai-backend/src/index.ts
@@ -7,15 +7,15 @@ import type {
import {
CloudflareVectorizeStore,
} from "@langchain/cloudflare";
-import { Ai } from '@cloudflare/ai';
import { OpenAIEmbeddings } from "./OpenAIEmbedder";
-import { AiTextGenerationOutput } from "@cloudflare/ai/dist/ai/tasks/text-generation";
+import { GoogleGenerativeAI } from "@google/generative-ai";
export interface Env {
VECTORIZE_INDEX: VectorizeIndex;
AI: Fetcher;
SECURITY_KEY: string;
OPENAI_API_KEY: string;
+ GOOGLE_AI_API_KEY: string;
}
@@ -38,7 +38,10 @@ export default {
const store = new CloudflareVectorizeStore(embeddings, {
index: env.VECTORIZE_INDEX,
});
- const ai = new Ai(env.AI)
+ // const ai = new Ai(env.AI)
+
+ const genAI = new GoogleGenerativeAI(env.GOOGLE_AI_API_KEY);
+ const model = genAI.getGenerativeModel({ model: "gemini-pro" });
if (pathname === "/add" && request.method === "POST") {
@@ -119,22 +122,27 @@ export default {
return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 });
}
- const metadatas = vec.map(({ metadata }) => metadata)
-
- console.log(metadatas)
-
- // TODO: TAKE ALL THE HIGH SCORED IDS INTO CONSIDERATION
- const output: AiTextGenerationOutput = await ai.run('@hf/thebloke/mistral-7b-instruct-v0.1-awq', {
- prompt: `You are an agent that summarizes a page based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${vec[0].metadata!.text} \nAnswer this question based on the context. Question: ${query}\nAnswer:`,
- stream: true
- }) as ReadableStream
-
-
- return new Response(output, {
- headers: {
- "content-type": "text/event-stream",
- },
- });
+ const preparedContext = vec.slice(0, 3).map(({ metadata }) => `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`).join("\n\n");
+
+ const prompt = `You are an agent that summarizes a page based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${preparedContext} \nAnswer this question based on the context. Question: ${query}\nAnswer:`
+ const output = await model.generateContentStream(prompt);
+
+ const response = new Response(
+ new ReadableStream({
+ async start(controller) {
+ const converter = new TextEncoder();
+ for await (const chunk of output.stream) {
+ const chunkText = await chunk.text();
+ const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n");
+ controller.enqueue(encodedChunk);
+ }
+ const doneChunk = converter.encode("data: [DONE]");
+ controller.enqueue(doneChunk);
+ controller.close();
+ }
+ })
+ );
+ return response;
}
else if (pathname === "/ask" && request.method === "POST") {
@@ -146,17 +154,26 @@ export default {
return new Response(JSON.stringify({ message: "Invalid Page Content" }), { status: 400 });
}
- const output: AiTextGenerationOutput = await ai.run('@hf/thebloke/mistral-7b-instruct-v0.1-awq', {
- prompt: `You are an agent that answers a question based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${body.query} \nAnswer this question based on the context. Question: ${body.query}\nAnswer:`,
- stream: true
- }) as ReadableStream
-
-
- return new Response(output, {
- headers: {
- "content-type": "text/event-stream",
- },
- });
+ const prompt = `You are an agent that answers a question based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${body.query} \nAnswer this question based on the context. Question: ${body.query}\nAnswer:`
+ const output = await model.generateContentStream(prompt);
+
+ const response = new Response(
+ new ReadableStream({
+ async start(controller) {
+ const converter = new TextEncoder();
+ for await (const chunk of output.stream) {
+ const chunkText = await chunk.text();
+ console.log(chunkText);
+ const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n");
+ controller.enqueue(encodedChunk);
+ }
+ const doneChunk = converter.encode("data: [DONE] \n\n");
+ controller.enqueue(doneChunk);
+ controller.close();
+ }
+ })
+ );
+ return response;
}
return new Response(JSON.stringify({ message: "Invalid Request" }), { status: 400 });
diff --git a/apps/web/src/components/QueryAI.tsx b/apps/web/src/components/QueryAI.tsx
index 811dd899..3cb14178 100644
--- a/apps/web/src/components/QueryAI.tsx
+++ b/apps/web/src/components/QueryAI.tsx
@@ -82,6 +82,11 @@ function QueryAI() {
const response = await fetch(`/api/query?q=${input}`);
+ if (response.status !== 200) {
+ setIsAiLoading(false);
+ return;
+ }
+
if (response.body) {
let reader = response.body.getReader();
let decoder = new TextDecoder('utf-8');
diff --git a/package.json b/package.json
index 5ba4eea6..145a2384 100644
--- a/package.json
+++ b/package.json
@@ -39,6 +39,7 @@
"@cloudflare/ai": "^1.0.52",
"@cloudflare/next-on-pages-next-dev": "^0.0.1",
"@crxjs/vite-plugin": "^1.0.14",
+ "@google/generative-ai": "^0.3.1",
"@heroicons/react": "^2.1.1",
"@langchain/cloudflare": "^0.0.3",
"@radix-ui/colors": "^3.0.0",