aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--apps/cf-ai-backend/src/index.ts184
-rw-r--r--apps/cf-ai-backend/src/types.ts2
-rw-r--r--apps/web/app/(dash)/chat/chatWindow.tsx17
-rw-r--r--apps/web/app/api/chat/route.ts24
-rw-r--r--packages/shared-types/index.ts8
5 files changed, 130 insertions, 105 deletions
diff --git a/apps/cf-ai-backend/src/index.ts b/apps/cf-ai-backend/src/index.ts
index e89d170c..224f2a42 100644
--- a/apps/cf-ai-backend/src/index.ts
+++ b/apps/cf-ai-backend/src/index.ts
@@ -1,6 +1,6 @@
import { z } from "zod";
import { Hono } from "hono";
-import { CoreMessage, generateText, streamText } from "ai";
+import { CoreMessage, streamText } from "ai";
import { chatObj, Env, vectorObj } from "./types";
import {
batchCreateChunksAndEmbeddings,
@@ -193,90 +193,102 @@ app.post(
const body = c.req.valid("json");
const sourcesOnly = query.sourcesOnly === "true";
+
+ // Return early for dumb requests
+ if (sourcesOnly && body.sources) {
+ return c.json(body.sources);
+ }
+
const spaces = query.spaces?.split(",") ?? [undefined];
// Get the AI model maker and vector store
const { model, store } = await initQuery(c, query.model);
- const filter: VectorizeVectorMetadataFilter = {
- [`user-${query.user}`]: 1,
- };
- console.log("Spaces", spaces);
-
- // Converting the query to a vector so that we can search for similar vectors
- const queryAsVector = await store.embeddings.embedQuery(query.query);
- const responses: VectorizeMatches = { matches: [], count: 0 };
-
- console.log("hello world", spaces);
-
- // SLICED to 5 to avoid too many queries
- for (const space of spaces.slice(0, 5)) {
- console.log("space", space);
- if (!space && spaces.length > 1) {
- // it's possible for space list to be [undefined] so we only add space filter conditionally
- filter[`space-${query.user}-${space}`] = 1;
+ if (!body.sources) {
+ const filter: VectorizeVectorMetadataFilter = {
+ [`user-${query.user}`]: 1,
+ };
+ console.log("Spaces", spaces);
+
+ // Converting the query to a vector so that we can search for similar vectors
+ const queryAsVector = await store.embeddings.embedQuery(query.query);
+ const responses: VectorizeMatches = { matches: [], count: 0 };
+
+ console.log("hello world", spaces);
+
+ // SLICED to 5 to avoid too many queries
+ for (const space of spaces.slice(0, 5)) {
+ console.log("space", space);
+ if (!space && spaces.length > 1) {
+ // it's possible for space list to be [undefined] so we only add space filter conditionally
+ filter[`space-${query.user}-${space}`] = 1;
+ }
+
+ // Because there's no OR operator in the filter, we have to make multiple queries
+ const resp = await c.env.VECTORIZE_INDEX.query(queryAsVector, {
+ topK: query.topK,
+ filter,
+ returnMetadata: true,
+ });
+
+ // Basically recreating the response object
+ if (resp.count > 0) {
+ responses.matches.push(...resp.matches);
+ responses.count += resp.count;
+ }
}
- // Because there's no OR operator in the filter, we have to make multiple queries
- const resp = await c.env.VECTORIZE_INDEX.query(queryAsVector, {
- topK: query.topK,
- filter,
- returnMetadata: true,
- });
-
- // Basically recreating the response object
- if (resp.count > 0) {
- responses.matches.push(...resp.matches);
- responses.count += resp.count;
- }
- }
+ const minScore = Math.min(...responses.matches.map(({ score }) => score));
+ const maxScore = Math.max(...responses.matches.map(({ score }) => score));
+
+ // We are "normalising" the scores - if all of them are on top, we want to make sure that
+ // we have a way to filter out the noise.
+ const normalizedData = responses.matches.map((data) => ({
+ ...data,
+ normalizedScore:
+ maxScore !== minScore
+ ? 1 + ((data.score - minScore) / (maxScore - minScore)) * 98
+ : 50, // If all scores are the same, set them to the middle of the scale
+ }));
+
+ let highScoreData = normalizedData.filter(
+ ({ normalizedScore }) => normalizedScore > 50,
+ );
- const minScore = Math.min(...responses.matches.map(({ score }) => score));
- const maxScore = Math.max(...responses.matches.map(({ score }) => score));
-
- // We are "normalising" the scores - if all of them are on top, we want to make sure that
- // we have a way to filter out the noise.
- const normalizedData = responses.matches.map((data) => ({
- ...data,
- normalizedScore:
- maxScore !== minScore
- ? 1 + ((data.score - minScore) / (maxScore - minScore)) * 98
- : 50, // If all scores are the same, set them to the middle of the scale
- }));
-
- let highScoreData = normalizedData.filter(
- ({ normalizedScore }) => normalizedScore > 50,
- );
+ // If the normalsation is not done properly, we have a fallback to just get the
+ // top 3 scores
+ if (highScoreData.length === 0) {
+ highScoreData = normalizedData
+ .sort((a, b) => b.score - a.score)
+ .slice(0, 3);
+ }
- // If the normalsation is not done properly, we have a fallback to just get the
- // top 3 scores
- if (highScoreData.length === 0) {
- highScoreData = normalizedData
- .sort((a, b) => b.score - a.score)
- .slice(0, 3);
- }
+ const sortedHighScoreData = highScoreData.sort(
+ (a, b) => b.normalizedScore - a.normalizedScore,
+ );
- const sortedHighScoreData = highScoreData.sort(
- (a, b) => b.normalizedScore - a.normalizedScore,
- );
+ body.sources = {
+ normalizedData,
+ };
- // So this is kinda hacky, but the frontend needs to do 2 calls to get sources and chat.
- // I think this is fine for now, but we can improve this later.
- if (sourcesOnly) {
- const idsAsStrings = sortedHighScoreData.map((dataPoint) =>
- dataPoint.id.toString(),
- );
+ // So this is kinda hacky, but the frontend needs to do 2 calls to get sources and chat.
+ // I think this is fine for now, but we can improve this later.
+ if (sourcesOnly) {
+ const idsAsStrings = sortedHighScoreData.map((dataPoint) =>
+ dataPoint.id.toString(),
+ );
- const storedContent = await Promise.all(
- idsAsStrings.map(async (id) => await c.env.KV.get(id)),
- );
+ const storedContent = await Promise.all(
+ idsAsStrings.map(async (id) => await c.env.KV.get(id)),
+ );
- const metadata = normalizedData.map((datapoint) => datapoint.metadata);
+ const metadata = normalizedData.map((datapoint) => datapoint.metadata);
- return c.json({ ids: storedContent, metadata });
+ return c.json({ ids: storedContent, metadata, normalizedData });
+ }
}
- const preparedContext = normalizedData.map(
+ const preparedContext = body.sources.normalizedData.map(
({ metadata, score, normalizedScore }) => ({
context: `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`,
score,
@@ -330,20 +342,28 @@ app.delete(
},
);
-// ERROR #1 - this is the api that the editor uses, it is just a scrape off of /api/chat so you may check that out
-app.get('/api/editorai', zValidator(
- "query",
- z.object({
- context: z.string(),
- request: z.string(),
- }),
-), async (c)=> {
- const { context, request } = c.req.valid("query");
- const { model } = await initQuery(c);
+// ERROR #1 - this is the api that the editor uses, it is just a scrape off of /api/chat so you may check that out
+app.get(
+ "/api/editorai",
+ zValidator(
+ "query",
+ z.object({
+ context: z.string(),
+ request: z.string(),
+ }),
+ ),
+ async (c) => {
+ const { context, request } = c.req.valid("query");
+ const { model } = await initQuery(c);
- const response = await streamText({ model, prompt: `${request}-${context}`, maxTokens: 224 });
+ const response = await streamText({
+ model,
+ prompt: `${request}-${context}`,
+ maxTokens: 224,
+ });
- return response.toTextStreamResponse();
-})
+ return response.toTextStreamResponse();
+ },
+);
export default app;
diff --git a/apps/cf-ai-backend/src/types.ts b/apps/cf-ai-backend/src/types.ts
index 417d6320..dc97777c 100644
--- a/apps/cf-ai-backend/src/types.ts
+++ b/apps/cf-ai-backend/src/types.ts
@@ -1,3 +1,4 @@
+import { sourcesZod } from "@repo/shared-types";
import { z } from "zod";
export type Env = {
@@ -37,6 +38,7 @@ export const contentObj = z.object({
export const chatObj = z.object({
chatHistory: z.array(contentObj).optional(),
+ sources: sourcesZod.optional(),
});
export const vectorObj = z.object({
diff --git a/apps/web/app/(dash)/chat/chatWindow.tsx b/apps/web/app/(dash)/chat/chatWindow.tsx
index 8485d0b2..9a18cfe7 100644
--- a/apps/web/app/(dash)/chat/chatWindow.tsx
+++ b/apps/web/app/(dash)/chat/chatWindow.tsx
@@ -6,7 +6,7 @@ import QueryInput from "../home/queryinput";
import { cn } from "@repo/ui/lib/utils";
import { motion } from "framer-motion";
import { useRouter } from "next/navigation";
-import { ChatHistory } from "@repo/shared-types";
+import { ChatHistory, sourcesZod } from "@repo/shared-types";
import {
Accordion,
AccordionContent,
@@ -20,15 +20,10 @@ import rehypeKatex from "rehype-katex";
import rehypeHighlight from "rehype-highlight";
import { code, p } from "./markdownRenderHelpers";
import { codeLanguageSubset } from "@/lib/constants";
-import { z } from "zod";
import { toast } from "sonner";
import Link from "next/link";
import { createChatObject } from "@/app/actions/doers";
-import {
- ClipboardIcon,
- ShareIcon,
- SpeakerWaveIcon,
-} from "@heroicons/react/24/outline";
+import { ClipboardIcon } from "@heroicons/react/24/outline";
import { SendIcon } from "lucide-react";
function ChatWindow({
@@ -83,11 +78,6 @@ function ChatWindow({
// TODO: handle this properly
const sources = await sourcesFetch.json();
- const sourcesZod = z.object({
- ids: z.array(z.string()),
- metadata: z.array(z.any()),
- });
-
const sourcesParsed = sourcesZod.safeParse(sources);
if (!sourcesParsed.success) {
@@ -100,7 +90,6 @@ function ChatWindow({
behavior: "smooth",
});
- // Assuming this is part of a larger function within a React component
const updateChatHistoryAndFetch = async () => {
// Step 1: Update chat history with the assistant's response
await new Promise((resolve) => {
@@ -143,7 +132,7 @@ function ChatWindow({
`/api/chat?q=${query}&spaces=${spaces}&threadId=${threadId}`,
{
method: "POST",
- body: JSON.stringify({ chatHistory }),
+ body: JSON.stringify({ chatHistory, sources: sourcesParsed.data }),
},
);
diff --git a/apps/web/app/api/chat/route.ts b/apps/web/app/api/chat/route.ts
index d0e53066..d1730baa 100644
--- a/apps/web/app/api/chat/route.ts
+++ b/apps/web/app/api/chat/route.ts
@@ -1,5 +1,10 @@
import { type NextRequest } from "next/server";
-import { ChatHistoryZod, convertChatHistoryList } from "@repo/shared-types";
+import {
+ ChatHistory,
+ ChatHistoryZod,
+ convertChatHistoryList,
+ SourcesFromApi,
+} from "@repo/shared-types";
import { ensureAuth } from "../ensureAuth";
import { z } from "zod";
@@ -23,7 +28,11 @@ export async function POST(req: NextRequest) {
const sourcesOnly = url.searchParams.get("sourcesOnly") ?? "false";
- const chatHistory = await req.json();
+ const jsonRequest = (await req.json()) as {
+ chatHistory: ChatHistory[];
+ sources: SourcesFromApi[] | undefined;
+ };
+ const { chatHistory, sources } = jsonRequest;
if (!query || query.trim.length < 0) {
return new Response(JSON.stringify({ message: "Invalid query" }), {
@@ -31,9 +40,7 @@ export async function POST(req: NextRequest) {
});
}
- const validated = z
- .object({ chatHistory: z.array(ChatHistoryZod) })
- .safeParse(chatHistory ?? []);
+ const validated = z.array(ChatHistoryZod).safeParse(chatHistory ?? []);
if (!validated.success) {
return new Response(
@@ -45,9 +52,7 @@ export async function POST(req: NextRequest) {
);
}
- const modelCompatible = await convertChatHistoryList(
- validated.data.chatHistory,
- );
+ const modelCompatible = await convertChatHistoryList(validated.data);
const resp = await fetch(
`${process.env.BACKEND_BASE_URL}/api/chat?query=${query}&user=${session.user.id}&sourcesOnly=${sourcesOnly}&spaces=${spaces}`,
@@ -59,12 +64,13 @@ export async function POST(req: NextRequest) {
method: "POST",
body: JSON.stringify({
chatHistory: modelCompatible,
+ sources,
}),
},
);
if (sourcesOnly == "true") {
- const data = await resp.json();
+ const data = (await resp.json()) as SourcesFromApi;
return new Response(JSON.stringify(data), { status: 200 });
}
diff --git a/packages/shared-types/index.ts b/packages/shared-types/index.ts
index 318684b7..9d5caa40 100644
--- a/packages/shared-types/index.ts
+++ b/packages/shared-types/index.ts
@@ -59,3 +59,11 @@ export function convertChatHistoryList(
return convertedChats;
}
+
+export const sourcesZod = z.object({
+ ids: z.array(z.string()),
+ metadata: z.array(z.any()),
+ normalizedData: z.array(z.any()).optional(),
+});
+
+export type SourcesFromApi = z.infer<typeof sourcesZod>;