| author | Dhravya <[email protected]> | 2024-04-02 18:11:01 -0700 |
|---|---|---|
| committer | Dhravya <[email protected]> | 2024-04-02 18:11:01 -0700 |
| commit | 4380ea824a274a35febf12ab9cd8d0cfc8544c1c (patch) | |
| tree | 734e04c60f8a709086aec6270050cc0a2164298c | |
| parent | gracefully end stream (diff) | |
| download | supermemory-4380ea824a274a35febf12ab9cd8d0cfc8544c1c.tar.xz supermemory-4380ea824a274a35febf12ab9cd8d0cfc8544c1c.zip | |
refactored ai backend code for composability
| -rw-r--r-- | apps/cf-ai-backend/src/env.d.ts | 7 |
| -rw-r--r-- | apps/cf-ai-backend/src/index.ts | 154 |
| -rw-r--r-- | apps/cf-ai-backend/src/routes.ts | 29 |
| -rw-r--r-- | apps/cf-ai-backend/src/routes/add.ts | 36 |
| -rw-r--r-- | apps/cf-ai-backend/src/routes/ask.ts | 35 |
| -rw-r--r-- | apps/cf-ai-backend/src/routes/query.ts | 72 |
6 files changed, 192 insertions, 141 deletions
```diff
diff --git a/apps/cf-ai-backend/src/env.d.ts b/apps/cf-ai-backend/src/env.d.ts
new file mode 100644
index 00000000..acbd6c43
--- /dev/null
+++ b/apps/cf-ai-backend/src/env.d.ts
@@ -0,0 +1,7 @@
+interface Env {
+  VECTORIZE_INDEX: VectorizeIndex;
+  AI: Fetcher;
+  SECURITY_KEY: string;
+  OPENAI_API_KEY: string;
+  GOOGLE_AI_API_KEY: string;
+}
diff --git a/apps/cf-ai-backend/src/index.ts b/apps/cf-ai-backend/src/index.ts
index f4843601..f55c465b 100644
--- a/apps/cf-ai-backend/src/index.ts
+++ b/apps/cf-ai-backend/src/index.ts
@@ -9,27 +9,18 @@ import {
 } from "@langchain/cloudflare";
 import { OpenAIEmbeddings } from "./OpenAIEmbedder";
 import { GoogleGenerativeAI } from "@google/generative-ai";
-
-export interface Env {
-  VECTORIZE_INDEX: VectorizeIndex;
-  AI: Fetcher;
-  SECURITY_KEY: string;
-  OPENAI_API_KEY: string;
-  GOOGLE_AI_API_KEY: string;
-}
-
+import routeMap from "./routes";

 function isAuthorized(request: Request, env: Env): boolean {
   return request.headers.get('X-Custom-Auth-Key') === env.SECURITY_KEY;
 }

 export default {
-  async fetch(request: Request, env: Env) {
+  async fetch(request: Request, env: Env, ctx: ExecutionContext) {
     if (!isAuthorized(request, env)) {
       return new Response('Unauthorized', { status: 401 });
     }

-    const pathname = new URL(request.url).pathname;
     const embeddings = new OpenAIEmbeddings({
       apiKey: env.OPENAI_API_KEY,
       modelName: 'text-embedding-3-small',
@@ -40,143 +31,24 @@ export default {
     });

     const genAI = new GoogleGenerativeAI(env.GOOGLE_AI_API_KEY);
-    const model = genAI.getGenerativeModel({ model: "gemini-pro" });
-
-    // TODO: Add /chat endpoint to chat with the AI in a conversational manner
-    if (pathname === "/add" && request.method === "POST") {
-
-      const body = await request.json() as {
-        pageContent: string,
-        title?: string,
-        description?: string,
-        url: string,
-        user: string
-      };
-
-
-      if (!body.pageContent || !body.url) {
-        return new Response(JSON.stringify({ message: "Invalid Page Content" }), { status: 400 });
-      }
-      const newPageContent = `Title: ${body.title}\nDescription: ${body.description}\nURL: ${body.url}\nContent: ${body.pageContent}`
-
-
-      await store.addDocuments([
-        {
-          pageContent: newPageContent,
-          metadata: {
-            title: body.title ?? "",
-            description: body.description ?? "",
-            url: body.url,
-            user: body.user,
-          },
-        },
-      ], {
-        ids: [`${body.url}`]
-      })
-
-      return new Response(JSON.stringify({ message: "Document Added" }), { status: 200 });
-    }
-
-    else if (pathname === "/query" && request.method === "GET") {
-      const queryparams = new URL(request.url).searchParams;
-      const query = queryparams.get("q");
-      const topK = parseInt(queryparams.get("topK") ?? "5");
-      const user = queryparams.get("user")
-
-      const sourcesOnly = (queryparams.get("sourcesOnly") ?? "false")
-      if (!user) {
-        return new Response(JSON.stringify({ message: "Invalid User" }), { status: 400 });
-      }
-
-      if (!query) {
-        return new Response(JSON.stringify({ message: "Invalid Query" }), { status: 400 });
-      }
-
-      const filter: VectorizeVectorMetadataFilter = {
-        user: {
-          $eq: user
-        }
-      }
-
-      const queryAsVector = await embeddings.embedQuery(query);
-
-      const resp = await env.VECTORIZE_INDEX.query(queryAsVector, {
-        topK,
-        filter
-      });
-
-      if (resp.count === 0) {
-        return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 });
-      }
-
-      const highScoreIds = resp.matches.filter(({ score }) => score > 0.3).map(({ id }) => id)
-
-      if (sourcesOnly === "true") {
-        return new Response(JSON.stringify({ ids: highScoreIds }), { status: 200 });
-      }
-
-      const vec = await env.VECTORIZE_INDEX.getByIds(highScoreIds)
-
-      if (vec.length === 0 || !vec[0].metadata) {
-        return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 });
-      }
+    const model = genAI.getGenerativeModel({ model: "gemini-pro" });

-      const preparedContext = vec.slice(0, 3).map(({ metadata }) => `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`).join("\n\n");
+    const url = new URL(request.url);
+    const path = url.pathname;
+    const method = request.method.toUpperCase();

-      const prompt = `You are an agent that summarizes a page based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${preparedContext} \nAnswer this question based on the context. Question: ${query}\nAnswer:`
-      const output = await model.generateContentStream(prompt);
+    const routeHandlers = routeMap.get(path);

-      const response = new Response(
-        new ReadableStream({
-          async start(controller) {
-            const converter = new TextEncoder();
-            for await (const chunk of output.stream) {
-              const chunkText = await chunk.text();
-              const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n");
-              controller.enqueue(encodedChunk);
-            }
-            const doneChunk = converter.encode("data: [DONE]");
-            controller.enqueue(doneChunk);
-            controller.close();
-          }
-        })
-      );
-      return response;
+    if (!routeHandlers) {
+      return new Response('Not Found', { status: 404 });
     }

-    else if (pathname === "/ask" && request.method === "POST") {
-      const body = await request.json() as {
-        query: string
-      };
-
-      if (!body.query) {
-        return new Response(JSON.stringify({ message: "Invalid Page Content" }), { status: 400 });
-      }
+    const handler = routeHandlers[method];

-      const prompt = `You are an agent that answers a question based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${body.query} \nAnswer this question based on the context. Question: ${body.query}\nAnswer:`
-      const output = await model.generateContentStream(prompt);
-
-      const response = new Response(
-        new ReadableStream({
-          async start(controller) {
-            const converter = new TextEncoder();
-            for await (const chunk of output.stream) {
-              const chunkText = await chunk.text();
-              console.log(chunkText);
-              const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n");
-              controller.enqueue(encodedChunk);
-            }
-            const doneChunk = converter.encode("data: [DONE]");
-            controller.enqueue(doneChunk);
-            controller.close();
-          }
-        })
-      );
-      return response;
+    if (!handler) {
+      return new Response('Method Not Allowed', { status: 405 });
     }
-
-    return new Response(JSON.stringify({ message: "Invalid Request" }), { status: 400 });
-
+    return await handler(request, store, embeddings, model, env, ctx);
   },
 };
diff --git a/apps/cf-ai-backend/src/routes.ts b/apps/cf-ai-backend/src/routes.ts
new file mode 100644
index 00000000..376c313e
--- /dev/null
+++ b/apps/cf-ai-backend/src/routes.ts
@@ -0,0 +1,29 @@
+import { CloudflareVectorizeStore } from '@langchain/cloudflare';
+import * as apiAdd from './routes/add';
+import * as apiQuery from "./routes/query"
+import * as apiAsk from "./routes/ask"
+import { OpenAIEmbeddings } from './OpenAIEmbedder';
+import { GenerativeModel } from '@google/generative-ai';
+import { Request } from '@cloudflare/workers-types';
+
+
+type RouteHandler = (request: Request, store: CloudflareVectorizeStore, embeddings: OpenAIEmbeddings, model: GenerativeModel, env: Env, ctx?: ExecutionContext) => Promise<Response>;
+
+const routeMap = new Map<string, Record<string, RouteHandler>>();
+
+routeMap.set('/add', {
+  POST: apiAdd.POST,
+});
+
+routeMap.set('/query', {
+  GET: apiQuery.GET,
+});
+
+routeMap.set('/ask', {
+  POST: apiAsk.POST,
+});
+
+// Add more route mappings as needed
+// routeMap.set('/api/otherRoute', { ... });
+
+export default routeMap;
diff --git a/apps/cf-ai-backend/src/routes/add.ts b/apps/cf-ai-backend/src/routes/add.ts
new file mode 100644
index 00000000..fb4e7121
--- /dev/null
+++ b/apps/cf-ai-backend/src/routes/add.ts
@@ -0,0 +1,36 @@
+import { Request } from "@cloudflare/workers-types";
+import { type CloudflareVectorizeStore } from "@langchain/cloudflare";
+
+export async function POST(request: Request, store: CloudflareVectorizeStore) {
+  const body = await request.json() as {
+    pageContent: string,
+    title?: string,
+    description?: string,
+    category?: string,
+    url: string,
+    user: string
+  };
+
+  if (!body.pageContent || !body.url) {
+    return new Response(JSON.stringify({ message: "Invalid Page Content" }), { status: 400 });
+  }
+  const newPageContent = `Title: ${body.title}\nDescription: ${body.description}\nURL: ${body.url}\nContent: ${body.pageContent}`
+
+
+  await store.addDocuments([
+    {
+      pageContent: newPageContent,
+      metadata: {
+        title: body.title ?? "",
+        description: body.description ?? "",
+        category: body.category ?? "",
+        url: body.url,
+        user: body.user,
+      },
+    },
+  ], {
+    ids: [`${body.url}`]
+  })
+
+  return new Response(JSON.stringify({ message: "Document Added" }), { status: 200 });
+}
diff --git a/apps/cf-ai-backend/src/routes/ask.ts b/apps/cf-ai-backend/src/routes/ask.ts
new file mode 100644
index 00000000..1c48dde8
--- /dev/null
+++ b/apps/cf-ai-backend/src/routes/ask.ts
@@ -0,0 +1,35 @@
+import { GenerativeModel } from "@google/generative-ai";
+import { OpenAIEmbeddings } from "../OpenAIEmbedder";
+import { CloudflareVectorizeStore } from "@langchain/cloudflare";
+import { Request } from "@cloudflare/workers-types";
+
+export async function POST(request: Request, _: CloudflareVectorizeStore, embeddings: OpenAIEmbeddings, model: GenerativeModel, env?: Env) {
+  const body = await request.json() as {
+    query: string
+  };
+
+  if (!body.query) {
+    return new Response(JSON.stringify({ message: "Invalid Page Content" }), { status: 400 });
+  }
+
+  const prompt = `You are an agent that answers a question based on the query. don't say 'based on the context'.\n\n Context:\n${body.query} \nAnswer this question based on the context. Question: ${body.query}\nAnswer:`
+  const output = await model.generateContentStream(prompt);
+
+  const response = new Response(
+    new ReadableStream({
+      async start(controller) {
+        const converter = new TextEncoder();
+        for await (const chunk of output.stream) {
+          const chunkText = await chunk.text();
+          console.log(chunkText);
+          const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n");
+          controller.enqueue(encodedChunk);
+        }
+        const doneChunk = converter.encode("data: [DONE]");
+        controller.enqueue(doneChunk);
+        controller.close();
+      }
+    })
+  );
+  return response;
+}
diff --git a/apps/cf-ai-backend/src/routes/query.ts b/apps/cf-ai-backend/src/routes/query.ts
new file mode 100644
index 00000000..7439c44e
--- /dev/null
+++ b/apps/cf-ai-backend/src/routes/query.ts
@@ -0,0 +1,72 @@
+import { GenerativeModel } from "@google/generative-ai";
+import { OpenAIEmbeddings } from "../OpenAIEmbedder";
+import { CloudflareVectorizeStore } from "@langchain/cloudflare";
+import { Request } from "@cloudflare/workers-types";
+
+export async function GET(request: Request, _: CloudflareVectorizeStore, embeddings: OpenAIEmbeddings, model: GenerativeModel, env?: Env) {
+  const queryparams = new URL(request.url).searchParams;
+  const query = queryparams.get("q");
+  const topK = parseInt(queryparams.get("topK") ?? "5");
+  const user = queryparams.get("user")
+
+  const sourcesOnly = (queryparams.get("sourcesOnly") ?? "false")
+
+  if (!user) {
+    return new Response(JSON.stringify({ message: "Invalid User" }), { status: 400 });
+  }
+
+  if (!query) {
+    return new Response(JSON.stringify({ message: "Invalid Query" }), { status: 400 });
+  }
+
+  const filter: VectorizeVectorMetadataFilter = {
+    user: {
+      $eq: user
+    }
+  }
+
+  const queryAsVector = await embeddings.embedQuery(query);
+
+  const resp = await env!.VECTORIZE_INDEX.query(queryAsVector, {
+    topK,
+    filter
+  });
+
+  if (resp.count === 0) {
+    return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 });
+  }
+
+  const highScoreIds = resp.matches.filter(({ score }) => score > 0.3).map(({ id }) => id)
+
+  if (sourcesOnly === "true") {
+    return new Response(JSON.stringify({ ids: highScoreIds }), { status: 200 });
+  }
+
+  const vec = await env!.VECTORIZE_INDEX.getByIds(highScoreIds)
+
+  if (vec.length === 0 || !vec[0].metadata) {
+    return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 });
+  }
+
+  const preparedContext = vec.slice(0, 3).map(({ metadata }) => `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`).join("\n\n");
+
+  const prompt = `You are an agent that summarizes a page based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${preparedContext} \nAnswer this question based on the context. Question: ${query}\nAnswer:`
+  const output = await model.generateContentStream(prompt);
+
+  const response = new Response(
+    new ReadableStream({
+      async start(controller) {
+        const converter = new TextEncoder();
+        for await (const chunk of output.stream) {
+          const chunkText = await chunk.text();
+          const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n");
+          controller.enqueue(encodedChunk);
+        }
+        const doneChunk = converter.encode("data: [DONE]");
+        controller.enqueue(doneChunk);
+        controller.close();
+      }
+    })
+  );
+  return response;
+}
```