From cf1eeb86c34c92afd973defcada9841f87eeffa0 Mon Sep 17 00:00:00 2001 From: Dhravya Shah Date: Mon, 22 Jul 2024 22:44:18 -0500 Subject: added ability to use pro mode, to do complex timeseries/location searches --- apps/cf-ai-backend/src/index.ts | 71 +++++++++++++++++++++++--- apps/cf-ai-backend/src/utils/OpenAIEmbedder.ts | 2 +- apps/web/app/(dash)/chat/[chatid]/loading.tsx | 10 +++- apps/web/app/(dash)/chat/[chatid]/page.tsx | 4 +- apps/web/app/(dash)/chat/chatWindow.tsx | 48 +++++++++++++++-- apps/web/app/(dash)/header/autoBreadCrumbs.tsx | 4 +- apps/web/app/(dash)/home/history.tsx | 16 ++++-- apps/web/app/(dash)/home/page.tsx | 6 +-- apps/web/app/(dash)/home/queryinput.tsx | 14 +++-- apps/web/app/api/chat/route.ts | 3 +- apps/web/lib/searchParams.ts | 1 + packages/shared-types/index.ts | 4 ++ 12 files changed, 154 insertions(+), 29 deletions(-) diff --git a/apps/cf-ai-backend/src/index.ts b/apps/cf-ai-backend/src/index.ts index 13a33536..edfe31c2 100644 --- a/apps/cf-ai-backend/src/index.ts +++ b/apps/cf-ai-backend/src/index.ts @@ -407,6 +407,7 @@ app.post( spaces: z.string().optional(), sourcesOnly: z.string().optional().default("false"), model: z.string().optional().default("gpt-4o"), + proMode: z.string().optional().default("false"), }), ), zValidator("json", chatObj), @@ -415,6 +416,7 @@ app.post( const body = c.req.valid("json"); const sourcesOnly = query.sourcesOnly === "true"; + const proMode = query.proMode === "true"; // Return early for dumb requests if (sourcesOnly && body.sources) { @@ -422,7 +424,6 @@ app.post( } const spaces = query.spaces?.split(",") ?? [undefined]; - console.log(spaces); // Get the AI model maker and vector store const { model, store } = await initQuery(c, query.model); @@ -431,14 +432,67 @@ app.post( const filter: VectorizeVectorMetadataFilter = { [`user-${query.user}`]: 1, }; - console.log("Spaces", spaces); + + let proModeListedQueries: string[] = []; + + if (proMode) { + const addedToQuery = (await c.env.AI.run( + // @ts-ignore + "@hf/nousresearch/hermes-2-pro-mistral-7b", + { + messages: [ + { + role: "system", + content: + "You are a query enhancer. You must enhance a user's query to make it more relevant to what the user might be looking for. If there's any mention of dates like 'last summer' or 'this year', you should return 'DAY: X, MONTH: Y, YEAR: Z'. If there's any mention of locations, add that to the query too. Try to keep your responses as short as possible. Add to the user's query, don't replace it. Make sure to keep your answers short.", + }, + { role: "user", content: query.query }, + ], + tools: [ + { + type: "function", + function: { + name: "Enhance query get list", + description: + "Enhance the user's query to make it more relevant", + parameters: { + type: "object", + properties: { + listedQueries: { + type: "array", + description: "List of queries that the user has asked", + items: { + type: "string", + }, + }, + }, + required: ["Enhance query get list"], + }, + }, + }, + ], + max_tokens: 200, + }, + )) as { + response?: string; + tool_calls?: { + name: string; + arguments: { + listedQueries: string[]; + }; + }[]; + }; + + proModeListedQueries = + addedToQuery.tool_calls?.[0]?.arguments?.listedQueries ?? []; + } // Converting the query to a vector so that we can search for similar vectors - const queryAsVector = await store.embeddings.embedQuery(query.query); + const queryAsVector = await store.embeddings.embedQuery( + query.query + " " + proModeListedQueries.join(" "), + ); const responses: VectorizeMatches = { matches: [], count: 0 }; - console.log("hello world", spaces); - // SLICED to 5 to avoid too many queries for (const space of spaces.slice(0, 5)) { if (space && space.length >= 1) { @@ -506,7 +560,12 @@ app.post( const metadata = normalizedData.map((datapoint) => datapoint.metadata); - return c.json({ ids: storedContent, metadata, normalizedData }); + return c.json({ + ids: storedContent, + metadata, + normalizedData, + proModeListedQueries, + }); } } diff --git a/apps/cf-ai-backend/src/utils/OpenAIEmbedder.ts b/apps/cf-ai-backend/src/utils/OpenAIEmbedder.ts index 41d0802f..a151afc0 100644 --- a/apps/cf-ai-backend/src/utils/OpenAIEmbedder.ts +++ b/apps/cf-ai-backend/src/utils/OpenAIEmbedder.ts @@ -50,7 +50,7 @@ export class OpenAIEmbeddings { const json = zodTypeExpected.safeParse(data); if (!json.success) { - console.log(data); + console.log(JSON.stringify(data)); throw new Error("Invalid response from OpenAI: " + json.error.message); } diff --git a/apps/web/app/(dash)/chat/[chatid]/loading.tsx b/apps/web/app/(dash)/chat/[chatid]/loading.tsx index d28961a6..422adb8e 100644 --- a/apps/web/app/(dash)/chat/[chatid]/loading.tsx +++ b/apps/web/app/(dash)/chat/[chatid]/loading.tsx @@ -7,9 +7,15 @@ async function Page({ }: { searchParams: Record; }) { - const q = (searchParams?.q as string) ?? "from_loading"; + const q = (searchParams?.q as string) ?? ""; return ( - + ); } diff --git a/apps/web/app/(dash)/chat/[chatid]/page.tsx b/apps/web/app/(dash)/chat/[chatid]/page.tsx index 87fd0b19..29ffb3d8 100644 --- a/apps/web/app/(dash)/chat/[chatid]/page.tsx +++ b/apps/web/app/(dash)/chat/[chatid]/page.tsx @@ -9,7 +9,8 @@ async function Page({ params: { chatid: string }; searchParams: Record; }) { - const { firstTime, q, spaces } = chatSearchParamsCache.parse(searchParams); + const { firstTime, q, spaces, proMode } = + chatSearchParamsCache.parse(searchParams); let chat: Awaited>; @@ -31,6 +32,7 @@ async function Page({ spaces={spaces ?? []} initialChat={chat.data.length > 0 ? chat.data : undefined} threadId={params.chatid} + proMode={proMode} /> ); } diff --git a/apps/web/app/(dash)/chat/chatWindow.tsx b/apps/web/app/(dash)/chat/chatWindow.tsx index 28b99c9d..ed65bf7a 100644 --- a/apps/web/app/(dash)/chat/chatWindow.tsx +++ b/apps/web/app/(dash)/chat/chatWindow.tsx @@ -35,14 +35,19 @@ function ChatWindow({ parts: [], sources: [], }, + proModeProcessing: { + queries: [], + }, }, ], threadId, + proMode, }: { q: string; spaces: { id: number; name: string }[]; initialChat?: ChatHistory[]; threadId: string; + proMode: boolean; }) { const [layout, setLayout] = useState<"chat" | "initial">("chat"); const [chatHistory, setChatHistory] = useState(initialChat); @@ -63,13 +68,17 @@ function ChatWindow({ const router = useRouter(); - const getAnswer = async (query: string, spaces: string[]) => { + const getAnswer = async ( + query: string, + spaces: string[], + proMode: boolean = false, + ) => { if (query.trim() === "from_loading" || query.trim().length === 0) { return; } const sourcesFetch = await fetch( - `/api/chat?q=${query}&spaces=${spaces}&sourcesOnly=true&threadId=${threadId}`, + `/api/chat?q=${query}&spaces=${spaces}&sourcesOnly=true&threadId=${threadId}&proMode=${proMode}`, { method: "POST", body: JSON.stringify({ chatHistory }), @@ -91,6 +100,8 @@ function ChatWindow({ behavior: "smooth", }); + let proModeListedQueries: string[] = []; + const updateChatHistoryAndFetch = async () => { // Step 1: Update chat history with the assistant's response await new Promise((resolve) => { @@ -123,6 +134,11 @@ function ChatWindow({ ).length, })); + lastAnswer.proModeProcessing.queries = + sourcesParsed.data.proModeListedQueries ?? []; + + proModeListedQueries = lastAnswer.proModeProcessing.queries; + resolve(newChatHistory); return newChatHistory; }); @@ -130,7 +146,7 @@ function ChatWindow({ // Step 2: Fetch data from the API const resp = await fetch( - `/api/chat?q=${query}&spaces=${spaces}&threadId=${threadId}`, + `/api/chat?q=${(query += proModeListedQueries.join(" "))}&spaces=${spaces}&threadId=${threadId}`, { method: "POST", body: JSON.stringify({ chatHistory, sources: sourcesParsed.data }), @@ -181,6 +197,7 @@ function ChatWindow({ getAnswer( q, spaces.map((s) => `${s.id}`), + proMode, ); } } else { @@ -224,6 +241,28 @@ function ChatWindow({ {chat.question} + {chat.proModeProcessing?.queries?.length > 0 && ( +
+
+ Pro Mode +
+
+
+ {chat.proModeProcessing.queries.map( + (query, idx) => ( +
+ {query} +
+ ), + )} +
+
+
+ )} +
Answer
@@ -407,6 +446,9 @@ function ChatWindow({ parts: [], sources: [], }, + proModeProcessing: { + queries: [], + }, }, ]; }); diff --git a/apps/web/app/(dash)/header/autoBreadCrumbs.tsx b/apps/web/app/(dash)/header/autoBreadCrumbs.tsx index a823671c..671464ff 100644 --- a/apps/web/app/(dash)/header/autoBreadCrumbs.tsx +++ b/apps/web/app/(dash)/header/autoBreadCrumbs.tsx @@ -13,8 +13,6 @@ import React from "react"; function AutoBreadCrumbs() { const pathname = usePathname(); - console.log(pathname.split("/").filter(Boolean)); - return ( @@ -31,7 +29,7 @@ function AutoBreadCrumbs() { .filter(Boolean) .map((path, idx, paths) => ( <> - + {path.charAt(0).toUpperCase() + path.slice(1)} diff --git a/apps/web/app/(dash)/home/history.tsx b/apps/web/app/(dash)/home/history.tsx index 551197d8..922734df 100644 --- a/apps/web/app/(dash)/home/history.tsx +++ b/apps/web/app/(dash)/home/history.tsx @@ -26,9 +26,18 @@ const History = memo(() => {
    {!chatThreads_ && ( <> - - - + + + )} {chatThreads_?.map((thread) => ( @@ -36,6 +45,7 @@ const History = memo(() => { initial={{ opacity: 0, filter: "blur(1px)" }} animate={{ opacity: 1, filter: "blur(0px)" }} className="flex items-center gap-2 truncate" + key={thread.id} > {" "} diff --git a/apps/web/app/(dash)/home/page.tsx b/apps/web/app/(dash)/home/page.tsx index 53b6cd33..ebd4d84b 100644 --- a/apps/web/app/(dash)/home/page.tsx +++ b/apps/web/app/(dash)/home/page.tsx @@ -103,13 +103,12 @@ function Page({ searchParams }: { searchParams: Record }) { { + handleSubmit={async (q, spaces, proMode) => { if (q.length === 0) { toast.error("Query is required"); return; } - console.log("creating thread"); const threadid = await createChatThread(q); if (!threadid.success || !threadid.data) { @@ -117,9 +116,8 @@ function Page({ searchParams }: { searchParams: Record }) { return; } - console.log("pushing to chat"); push( - `/chat/${threadid.data}?spaces=${JSON.stringify(spaces)}&q=${q}`, + `/chat/${threadid.data}?spaces=${JSON.stringify(spaces)}&q=${q}&proMode=${proMode}`, ); }} initialSpaces={spaces} diff --git a/apps/web/app/(dash)/home/queryinput.tsx b/apps/web/app/(dash)/home/queryinput.tsx index e76e10cf..9f1e7292 100644 --- a/apps/web/app/(dash)/home/queryinput.tsx +++ b/apps/web/app/(dash)/home/queryinput.tsx @@ -1,6 +1,6 @@ "use client"; -import React, { useState } from "react"; +import React, { useEffect, useState } from "react"; import { FilterSpaces } from "./filterSpaces"; import { ArrowRightIcon } from "@repo/ui/icons"; import Image from "next/image"; @@ -20,7 +20,11 @@ function QueryInput({ }[]; initialQuery?: string; mini?: boolean; - handleSubmit: (q: string, spaces: { id: number; name: string }[]) => void; + handleSubmit: ( + q: string, + spaces: { id: number; name: string }[], + proMode: boolean, + ) => void; }) { const [q, setQ] = useState(initialQuery || ""); @@ -41,7 +45,7 @@ function QueryInput({ if (q.trim().length === 0) { return; } - handleSubmit(q, selectedSpaces); + handleSubmit(q, selectedSpaces, proMode); setQ(""); }} > @@ -58,7 +62,7 @@ function QueryInput({ if (q.trim().length === 0) { return; } - handleSubmit(q, selectedSpaces); + handleSubmit(q, selectedSpaces, proMode); setQ(""); } }} @@ -83,7 +87,7 @@ function QueryInput({ setProMode((prev) => !prev)} + onCheckedChange={(v) => setProMode(v)} id="pro-mode" about="Pro mode" /> diff --git a/apps/web/app/api/chat/route.ts b/apps/web/app/api/chat/route.ts index 004bfd3b..3b8d971b 100644 --- a/apps/web/app/api/chat/route.ts +++ b/apps/web/app/api/chat/route.ts @@ -27,6 +27,7 @@ export async function POST(req: NextRequest) { const spaces = url.searchParams.get("spaces"); const sourcesOnly = url.searchParams.get("sourcesOnly") ?? "false"; + const proMode = url.searchParams.get("proMode") === "true"; const jsonRequest = (await req.json()) as { chatHistory: ChatHistory[]; @@ -55,7 +56,7 @@ export async function POST(req: NextRequest) { const modelCompatible = await convertChatHistoryList(validated.data); const resp = await fetch( - `${process.env.BACKEND_BASE_URL}/api/chat?query=${query}&user=${session.user.id}&sourcesOnly=${sourcesOnly}&spaces=${spaces}`, + `${process.env.BACKEND_BASE_URL}/api/chat?query=${query}&user=${session.user.id}&sourcesOnly=${sourcesOnly}&spaces=${spaces}&proMode=${proMode}`, { headers: { Authorization: `Bearer ${process.env.BACKEND_SECURITY_KEY}`, diff --git a/apps/web/lib/searchParams.ts b/apps/web/lib/searchParams.ts index 2e8b1633..b90b560c 100644 --- a/apps/web/lib/searchParams.ts +++ b/apps/web/lib/searchParams.ts @@ -32,4 +32,5 @@ export const chatSearchParamsCache = createSearchParamsCache({ return valid.data; }), + proMode: parseAsBoolean.withDefault(false), }); diff --git a/packages/shared-types/index.ts b/packages/shared-types/index.ts index b3e84897..051e24a4 100644 --- a/packages/shared-types/index.ts +++ b/packages/shared-types/index.ts @@ -17,6 +17,9 @@ export const ChatHistoryZod = z.object({ sources: z.array(SourceZod), justification: z.string().optional(), }), + proModeProcessing: z.object({ + queries: z.array(z.string()), + }), }); export type ChatHistory = z.infer; @@ -77,6 +80,7 @@ export const sourcesZod = z.object({ ids: z.array(z.string()), metadata: z.array(z.any()), normalizedData: z.array(z.any()).optional(), + proModeListedQueries: z.array(z.string()).optional(), }); export type SourcesFromApi = z.infer; -- cgit v1.2.3