diff options
Diffstat (limited to 'apps/cf-ai-backend/src/index.ts')
| -rw-r--r-- | apps/cf-ai-backend/src/index.ts | 157 |
1 file changed, 129 insertions(+), 28 deletions(-)
diff --git a/apps/cf-ai-backend/src/index.ts b/apps/cf-ai-backend/src/index.ts index 19770dec..effdf517 100644 --- a/apps/cf-ai-backend/src/index.ts +++ b/apps/cf-ai-backend/src/index.ts @@ -1,6 +1,6 @@ import { z } from "zod"; import { Hono } from "hono"; -import { CoreMessage, streamText } from "ai"; +import { CoreMessage, generateText, streamText } from "ai"; import { chatObj, Env, vectorObj } from "./types"; import { batchCreateChunksAndEmbeddings, @@ -14,9 +14,17 @@ import { bearerAuth } from "hono/bearer-auth"; import { zValidator } from "@hono/zod-validator"; import chunkText from "./utils/chonker"; import { systemPrompt, template } from "./prompts/prompt1"; +import { swaggerUI } from "@hono/swagger-ui"; const app = new Hono<{ Bindings: Env }>(); +app.get( + "/ui", + swaggerUI({ + url: "/doc", + }), +); + // ------- MIDDLEWARES ------- app.use("*", poweredBy()); app.use("*", timing()); @@ -31,6 +39,17 @@ app.use("/api/", async (c, next) => { }); // ------- MIDDLEWARES END ------- +const fileSchema = z + .instanceof(File) + .refine( + (file) => file.size <= 10 * 1024 * 1024, + "File size should be less than 10MB", + ) // Validate file size + .refine( + (file) => ["image/jpeg", "image/png", "image/gif"].includes(file.type), + "Invalid file type", + ); // Validate file type + app.get("/", (c) => { return c.text("Supermemory backend API is running!"); }); @@ -54,6 +73,82 @@ app.post("/api/add", zValidator("json", vectorObj), async (c) => { return c.json({ status: "ok" }); }); +app.post( + "/api/add-with-image", + zValidator( + "form", + z.object({ + images: z + .array(fileSchema) + .min(1, "At least one image is required") + .optional(), + "images[]": z + .array(fileSchema) + .min(1, "At least one image is required") + .optional(), + text: z.string().optional(), + spaces: z.array(z.string()).optional(), + url: z.string(), + user: z.string(), + }), + (c) => { + console.log(c); + }, + ), + async (c) => { + const body = c.req.valid("form"); + + const { store } = 
await initQuery(c); + + if (!(body.images || body["images[]"])) { + return c.json({ status: "error", message: "No images found" }, 400); + } + + const imagePromises = (body.images ?? body["images[]"]).map( + async (image) => { + const buffer = await image.arrayBuffer(); + const input = { + image: [...new Uint8Array(buffer)], + prompt: + "What's in this image? caption everything you see in great detail. If it has text, do an OCR and extract all of it.", + max_tokens: 1024, + }; + const response = await c.env.AI.run( + "@cf/llava-hf/llava-1.5-7b-hf", + input, + ); + console.log(response.description); + return response.description; + }, + ); + + const imageDescriptions = await Promise.all(imagePromises); + + await batchCreateChunksAndEmbeddings({ + store, + body: { + url: body.url, + user: body.user, + type: "image", + description: + imageDescriptions.length > 1 + ? `A group of ${imageDescriptions.length} images on ${body.url}` + : imageDescriptions[0], + spaces: body.spaces, + pageContent: imageDescriptions.join("\n"), + title: "Image content from the web", + }, + chunks: [ + imageDescriptions, + ...(body.text ? chunkText(body.text, 1536) : []), + ].flat(), + context: c, + }); + + return c.json({ status: "ok" }); + }, +); + app.get( "/api/ask", zValidator( @@ -85,8 +180,8 @@ app.post( "query", z.object({ query: z.string(), - topK: z.number().optional().default(10), user: z.string(), + topK: z.number().optional().default(10), spaces: z.string().optional(), sourcesOnly: z.string().optional().default("false"), model: z.string().optional().default("gpt-4o"), @@ -97,30 +192,29 @@ app.post( const query = c.req.valid("query"); const body = c.req.valid("json"); - if (body.chatHistory) { - body.chatHistory = body.chatHistory.map((i) => ({ - ...i, - content: i.parts.length > 0 ? i.parts.join(" ") : i.content, - })); - } - const sourcesOnly = query.sourcesOnly === "true"; - const spaces = query.spaces?.split(",") || [undefined]; + const spaces = query.spaces?.split(",") ?? 
[undefined]; // Get the AI model maker and vector store const { model, store } = await initQuery(c, query.model); - const filter: VectorizeVectorMetadataFilter = { user: query.user }; + const filter: VectorizeVectorMetadataFilter = { + [`user-${query.user}`]: 1, + }; + console.log("Spaces", spaces); // Converting the query to a vector so that we can search for similar vectors const queryAsVector = await store.embeddings.embedQuery(query.query); const responses: VectorizeMatches = { matches: [], count: 0 }; + console.log("hello world", spaces); + // SLICED to 5 to avoid too many queries for (const space of spaces.slice(0, 5)) { - if (space !== undefined) { + console.log("space", space); + if (!space && spaces.length > 1) { // it's possible for space list to be [undefined] so we only add space filter conditionally - filter.space = space; + filter[`space-${query.user}-${space}`] = 1; } // Because there's no OR operator in the filter, we have to make multiple queries @@ -173,29 +267,20 @@ app.post( dataPoint.id.toString(), ); - // We are getting the content ID back, so that the frontend can show the actual sources properly. - // it IS a lot of DB calls, i completely agree. - // TODO: return metadata value here, so that the frontend doesn't have to re-fetch anything. 
const storedContent = await Promise.all( idsAsStrings.map(async (id) => await c.env.KV.get(id)), ); - return c.json({ ids: storedContent }); - } - - const vec = responses.matches.map((data) => ({ metadata: data.metadata })); + const metadata = normalizedData.map((datapoint) => datapoint.metadata); - const vecWithScores = vec.map((v, i) => ({ - ...v, - score: sortedHighScoreData[i].score, - normalisedScore: sortedHighScoreData[i].normalizedScore, - })); + return c.json({ ids: storedContent, metadata }); + } - const preparedContext = vecWithScores.map( - ({ metadata, score, normalisedScore }) => ({ + const preparedContext = normalizedData.map( + ({ metadata, score, normalizedScore }) => ({ context: `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`, score, - normalisedScore, + normalizedScore, }), ); @@ -245,4 +330,20 @@ app.delete( }, ); +app.get('/api/editorai', zValidator( + "query", + z.object({ + context: z.string(), + request: z.string(), + }), +), async (c)=> { + const { context, request } = c.req.valid("query"); + + const { model } = await initQuery(c); + + const {text} = await generateText({ model, prompt: `${request}-${context}`, maxTokens: 224 }); + + return c.json({completion: text}); +}) + export default app; |