| author | MaheshtheDev <[email protected]> | 2026-01-20 01:30:43 +0000 |
|---|---|---|
| committer | MaheshtheDev <[email protected]> | 2026-01-20 01:30:43 +0000 |
| commit | 32a7eff3af6e82fd3d2b419ecd016ab144d4c508 | |
| tree | bec738b44bc7f35f7f34083c1462ed5e8e9cc856 | packages/tools/src |
| parent | remove bun lock file (#686) | |
fix(tools): multi step agent prompt caching (#685)

Branch: 01-19-fix_tools_multi_step_agent_prompt_caching
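Context for the fix: in a Vercel AI SDK multi-step run, the middleware's `transformParams` hook fires on every step of the tool-call loop, so memories were re-fetched and re-injected on each step. Below is a minimal sketch of the scenario this patch targets, assuming a hypothetical `createSupermemoryMiddleware` factory (the real export name and path in the package may differ):

```ts
import { generateText, stepCountIs, tool, wrapLanguageModel } from "ai"
import { openai } from "@ai-sdk/openai"
import { z } from "zod"
// Hypothetical import - the actual factory export may be named differently.
import { createSupermemoryMiddleware } from "@supermemory/tools/vercel"

const model = wrapLanguageModel({
	model: openai("gpt-4o"),
	middleware: createSupermemoryMiddleware({
		containerTag: "user-123",
		apiKey: process.env.SUPERMEMORY_API_KEY ?? "",
		mode: "query",
	}),
})

// A tool-call loop: each step re-invokes transformParams with the same user
// turn. Before this patch every step hit the memory-search API again; after
// it, the first step fills a per-turn cache and later steps reuse the string.
const { text } = await generateText({
	model,
	prompt: "What's the weather where I live?",
	tools: {
		getWeather: tool({
			description: "Get the current weather for a city",
			inputSchema: z.object({ city: z.string() }),
			execute: async ({ city }) => `Sunny in ${city}`,
		}),
	},
	stopWhen: stepCountIs(5),
})
console.log(text)
```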
Diffstat (limited to 'packages/tools/src')
| -rw-r--r-- | packages/tools/src/vercel/memory-prompt.ts | 144 |
| -rw-r--r-- | packages/tools/src/vercel/middleware.ts | 129 |
| -rw-r--r-- | packages/tools/src/vercel/util.ts | 24 |
3 files changed, 205 insertions, 92 deletions
```diff
diff --git a/packages/tools/src/vercel/memory-prompt.ts b/packages/tools/src/vercel/memory-prompt.ts
index 3dfc203f..185a3c41 100644
--- a/packages/tools/src/vercel/memory-prompt.ts
+++ b/packages/tools/src/vercel/memory-prompt.ts
@@ -93,33 +93,38 @@ const supermemoryProfileSearch = async (
 	}
 }
 
-export const addSystemPrompt = async (
-	params: LanguageModelCallOptions,
-	containerTag: string,
-	logger: Logger,
-	mode: "profile" | "query" | "full",
-	baseUrl: string,
-	apiKey: string,
-	promptTemplate: PromptTemplate = defaultPromptTemplate,
-): Promise<LanguageModelCallOptions> => {
-	const systemPromptExists = params.prompt.some(
-		(prompt) => prompt.role === "system",
-	)
+/**
+ * Options for building memories text.
+ */
+export interface BuildMemoriesTextOptions {
+	containerTag: string
+	queryText: string
+	mode: "profile" | "query" | "full"
+	baseUrl: string
+	apiKey: string
+	logger: Logger
+	promptTemplate?: PromptTemplate
+}
 
-	const queryText =
-		mode !== "profile"
-			? params.prompt
-					.slice()
-					.reverse()
-					.find((prompt: { role: string }) => prompt.role === "user")
-					?.content?.filter(
-						(content: { type: string }) => content.type === "text",
-					)
-					?.map((content: { type: string; text: string }) =>
-						content.type === "text" ? content.text : "",
-					)
-					?.join(" ") || ""
-			: ""
+/**
+ * Fetches memories from the API, deduplicates them, and formats them into
+ * the final string to be injected into the system prompt.
+ *
+ * @param options - Configuration for building memories text
+ * @returns The final formatted memories string ready for injection
+ */
+export const buildMemoriesText = async (
+	options: BuildMemoriesTextOptions,
+): Promise<string> => {
+	const {
+		containerTag,
+		queryText,
+		mode,
+		baseUrl,
+		apiKey,
+		logger,
+		promptTemplate = defaultPromptTemplate,
+	} = options
 
 	const memoriesResponse = await supermemoryProfileSearch(
 		containerTag,
@@ -191,6 +196,27 @@
 		})
 	}
 
+	return memories
+}
+
+/**
+ * Injects memories string into params by appending to existing system prompt
+ * or creating a new one. Pure function - does not mutate the original params.
+ *
+ * @param params - The language model call options
+ * @param memories - The formatted memories string to inject
+ * @param logger - Logger for debug output
+ * @returns New params with memories injected into the system prompt
+ */
+export const injectMemoriesIntoParams = (
+	params: LanguageModelCallOptions,
+	memories: string,
+	logger: Logger,
+): LanguageModelCallOptions => {
+	const systemPromptExists = params.prompt.some(
+		(prompt) => prompt.role === "system",
+	)
+
 	if (systemPromptExists) {
 		logger.debug("Added memories to existing system prompt")
 		// biome-ignore lint/suspicious/noExplicitAny: Union type compatibility between V2 and V3 prompt types
@@ -212,3 +238,69 @@
 	] as any
 	return { ...params, prompt: newPrompt } as LanguageModelCallOptions
 }
+
+/**
+ * Extracts the query text from params based on mode.
+ * For "profile" mode, returns empty string (no query needed).
+ * For "query" or "full" mode, extracts the last user message text.
+ *
+ * @param params - The language model call options
+ * @param mode - The memory retrieval mode
+ * @returns The query text for memory search
+ */
+export const extractQueryText = (
+	params: LanguageModelCallOptions,
+	mode: "profile" | "query" | "full",
+): string => {
+	if (mode === "profile") {
+		return ""
+	}
+
+	const userMessage = params.prompt
+		.slice()
+		.reverse()
+		.find((prompt: { role: string }) => prompt.role === "user")
+
+	const content = userMessage?.content
+	if (!content) return ""
+
+	if (typeof content === "string") {
+		return content
+	}
+
+	// biome-ignore lint/suspicious/noExplicitAny: Union type compatibility between V2 and V3
+	return (content as any[])
+		.filter((part) => part.type === "text")
+		.map((part) => part.text || "")
+		.join(" ")
+}
+
+/**
+ * Adds memories to the system prompt by fetching from API and injecting.
+ * This is the original combined function, now implemented via helpers.
+ *
+ * @deprecated Prefer using buildMemoriesText + injectMemoriesIntoParams for caching support
+ */
+export const addSystemPrompt = async (
+	params: LanguageModelCallOptions,
+	containerTag: string,
+	logger: Logger,
+	mode: "profile" | "query" | "full",
+	baseUrl: string,
+	apiKey: string,
+	promptTemplate: PromptTemplate = defaultPromptTemplate,
+): Promise<LanguageModelCallOptions> => {
+	const queryText = extractQueryText(params, mode)
+
+	const memories = await buildMemoriesText({
+		containerTag,
+		queryText,
+		mode,
+		baseUrl,
+		apiKey,
+		logger,
+		promptTemplate,
+	})
+
+	return injectMemoriesIntoParams(params, memories, logger)
+}
```
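The shape of this refactor: `buildMemoriesText` owns the expensive network half and returns a plain string, while `injectMemoriesIntoParams` is a pure, cheap transform that can safely run on every model call. A minimal sketch of composing the two directly, mirroring what the deprecated `addSystemPrompt` now does internally (the container tag and base URL are placeholders):

```ts
import {
	buildMemoriesText,
	extractQueryText,
	injectMemoriesIntoParams,
} from "./memory-prompt"
import type { Logger } from "./logger"
import type { LanguageModelCallOptions } from "./util"

const withMemories = async (
	params: LanguageModelCallOptions,
	logger: Logger,
): Promise<LanguageModelCallOptions> => {
	// Network half: search the API and format the result into one string.
	// This string is what the middleware caches per user turn.
	const memories = await buildMemoriesText({
		containerTag: "user-123", // placeholder
		queryText: extractQueryText(params, "query"),
		mode: "query",
		baseUrl: "https://api.supermemory.ai", // placeholder base URL
		apiKey: process.env.SUPERMEMORY_API_KEY ?? "",
		logger,
	})
	// Pure half: returns new params with the string added to the system prompt.
	return injectMemoriesIntoParams(params, memories, logger)
}
```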
```diff
diff --git a/packages/tools/src/vercel/middleware.ts b/packages/tools/src/vercel/middleware.ts
index 8336b397..e3a9ac57 100644
--- a/packages/tools/src/vercel/middleware.ts
+++ b/packages/tools/src/vercel/middleware.ts
@@ -6,18 +6,18 @@ import {
 import { createLogger, type Logger } from "./logger"
 import {
 	type LanguageModelCallOptions,
-	type LanguageModelStreamPart,
-	type OutputContentItem,
 	getLastUserMessage,
 	filterOutSupermemories,
 } from "./util"
 import {
-	addSystemPrompt,
+	buildMemoriesText,
+	extractQueryText,
+	injectMemoriesIntoParams,
 	normalizeBaseUrl,
 	type PromptTemplate,
 } from "./memory-prompt"
 
-export const getConversationContent = (params: LanguageModelCallOptions) => {
+const getConversationContent = (params: LanguageModelCallOptions) => {
 	return params.prompt
 		.filter((msg) => msg.role !== "system" && msg.role !== "tool")
 		.map((msg) => {
@@ -36,7 +36,7 @@
 		.join("\n\n")
 }
 
-export const convertToConversationMessages = (
+const convertToConversationMessages = (
 	params: LanguageModelCallOptions,
 	assistantResponseText: string,
 ): ConversationMessage[] => {
@@ -160,7 +160,7 @@
 /**
  * Configuration options for the Supermemory middleware.
  */
-export interface SupermemoryMiddlewareOptions {
+interface SupermemoryMiddlewareOptions {
 	/** Container tag/identifier for memory search (e.g., user ID, project ID) */
 	containerTag: string
 	/** Supermemory API key */
@@ -188,7 +188,12 @@
 	promptTemplate?: PromptTemplate
 }
 
-export interface SupermemoryMiddlewareContext {
+/**
+ * Cached memories string for a user turn.
+ */
+type MemoryCache = string
+
+interface SupermemoryMiddlewareContext {
 	client: Supermemory
 	logger: Logger
 	containerTag: string
@@ -198,6 +203,11 @@
 	normalizedBaseUrl: string
 	apiKey: string
 	promptTemplate?: PromptTemplate
+	/**
+	 * Per-turn memory cache map. Stores the injected memories string for each
+	 * user turn (keyed by turnKey) to avoid redundant API calls during tool-call
+	 */
+	memoryCache: Map<string, MemoryCache>
 }
 
 export const createSupermemoryContext = (
@@ -234,9 +244,30 @@
 		normalizedBaseUrl,
 		apiKey,
 		promptTemplate,
+		memoryCache: new Map<string, MemoryCache>(),
 	}
 }
 
+/**
+ * Generates a cache key for the current turn based on context and user message.
+ * Normalizes the user message by trimming and collapsing whitespace.
+ */
+const makeTurnKey = (
+	ctx: SupermemoryMiddlewareContext,
+	userMessage: string,
+): string => {
+	const normalizedMessage = userMessage.trim().replace(/\s+/g, " ")
+	return `${ctx.containerTag}:${ctx.conversationId || ""}:${ctx.mode}:${normalizedMessage}`
+}
+
+/**
+ * Checks if this is a new user turn (last message is from user)
+ */
+const isNewUserTurn = (params: LanguageModelCallOptions): boolean => {
+	const lastMessage = params.prompt.at(-1)
+	return lastMessage?.role === "user"
+}
+
 export const transformParamsWithMemory = async (
 	params: LanguageModelCallOptions,
 	ctx: SupermemoryMiddlewareContext,
@@ -250,22 +281,42 @@
 		}
 	}
 
+	const turnKey = makeTurnKey(ctx, userMessage || "")
+	const isNewTurn = isNewUserTurn(params)
+
+	// Check if we can use cached memories
+	const cachedMemories = ctx.memoryCache.get(turnKey)
+	if (!isNewTurn && cachedMemories) {
+		ctx.logger.debug("Using cached memories: ", {
+			turnKey,
+		})
+		return injectMemoriesIntoParams(params, cachedMemories, ctx.logger)
+	}
+
 	ctx.logger.info("Starting memory search", {
 		containerTag: ctx.containerTag,
 		conversationId: ctx.conversationId,
 		mode: ctx.mode,
+		isNewTurn,
+		cacheHit: false,
 	})
 
-	const transformedParams = await addSystemPrompt(
-		params,
-		ctx.containerTag,
-		ctx.logger,
-		ctx.mode,
-		ctx.normalizedBaseUrl,
-		ctx.apiKey,
-		ctx.promptTemplate,
-	)
-	return transformedParams
+	const queryText = extractQueryText(params, ctx.mode)
+
+	const memories = await buildMemoriesText({
+		containerTag: ctx.containerTag,
+		queryText,
+		mode: ctx.mode,
+		baseUrl: ctx.normalizedBaseUrl,
+		apiKey: ctx.apiKey,
+		logger: ctx.logger,
+		promptTemplate: ctx.promptTemplate,
+	})
+
+	ctx.memoryCache.set(turnKey, memories)
+	ctx.logger.debug("Cached memories for turn", { turnKey })
+
+	return injectMemoriesIntoParams(params, memories, ctx.logger)
 }
 
 export const extractAssistantResponseText = (content: unknown[]): string => {
@@ -273,47 +324,3 @@
 		.map((item) => (item.type === "text" ? item.text || "" : ""))
 		.join("")
 }
-
-export const createStreamTransform = (
-	ctx: SupermemoryMiddlewareContext,
-	params: LanguageModelCallOptions,
-): {
-	transform: TransformStream<LanguageModelStreamPart, LanguageModelStreamPart>
-	getGeneratedText: () => string
-} => {
-	let generatedText = ""
-
-	const transform = new TransformStream<
-		LanguageModelStreamPart,
-		LanguageModelStreamPart
-	>({
-		transform(chunk, controller) {
-			if (chunk.type === "text-delta") {
-				generatedText += chunk.delta
-			}
-			controller.enqueue(chunk)
-		},
-		flush: async () => {
-			const userMessage = getLastUserMessage(params)
-			if (ctx.addMemory === "always" && userMessage && userMessage.trim()) {
-				saveMemoryAfterResponse(
-					ctx.client,
-					ctx.containerTag,
-					ctx.conversationId,
-					generatedText,
-					params,
-					ctx.logger,
-					ctx.apiKey,
-					ctx.normalizedBaseUrl,
-				)
-			}
-		},
-	})
-
-	return {
-		transform,
-		getGeneratedText: () => generatedText,
-	}
-}
-
-export { createLogger, type Logger, type OutputContentItem }
```
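Two details of the caching logic are easy to miss: a cache hit requires both a matching turn key and a last message that is *not* a fresh user message (a new user turn always refetches), and reusing the exact same memories string keeps the injected system prompt byte-identical across steps, which is what lets provider-side prompt caching stay warm. A self-contained sketch of the hit/miss decision, using simplified message types rather than the real V2/V3 union:

```ts
type Role = "system" | "user" | "assistant" | "tool"
interface Message {
	role: Role
	content: string
}

const memoryCache = new Map<string, string>()

// Same key recipe as makeTurnKey: identity fields plus the normalized message.
const turnKey = (tag: string, convId: string, mode: string, msg: string) =>
	`${tag}:${convId}:${mode}:${msg.trim().replace(/\s+/g, " ")}`

// Mirrors the middleware: only a mid-turn call with a cached entry skips the fetch.
const shouldFetch = (prompt: Message[], key: string): boolean => {
	const isNewTurn = prompt.at(-1)?.role === "user"
	return isNewTurn || !memoryCache.has(key)
}

// Note the whitespace normalization: this key matches "what's the weather?".
const key = turnKey("user-123", "conv-1", "query", "  what's  the weather? ")

// Step 1: the last message is from the user, so fetch and cache.
console.log(shouldFetch([{ role: "user", content: "what's the weather?" }], key)) // true
memoryCache.set(key, "relevant memories...")

// Step 2: the last message is a tool result for the same turn, so cache hit.
console.log(
	shouldFetch(
		[
			{ role: "user", content: "what's the weather?" },
			{ role: "assistant", content: "(tool call)" },
			{ role: "tool", content: "Sunny" },
		],
		key,
	),
) // false
```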
```diff
diff --git a/packages/tools/src/vercel/util.ts b/packages/tools/src/vercel/util.ts
index eec29859..01b655e1 100644
--- a/packages/tools/src/vercel/util.ts
+++ b/packages/tools/src/vercel/util.ts
@@ -100,16 +100,30 @@ export function convertProfileToMarkdown(data: ProfileMarkdownData): string {
 	return sections.join("\n\n")
 }
 
-export const getLastUserMessage = (params: LanguageModelCallOptions) => {
+export const getLastUserMessage = (
+	params: LanguageModelCallOptions,
+): string | undefined => {
 	const lastUserMessage = params.prompt
 		.slice()
 		.reverse()
 		.find((prompt: LanguageModelMessage) => prompt.role === "user")
-	const memories = lastUserMessage?.content
-		.filter((content) => content.type === "text")
-		.map((content) => (content as { type: "text"; text: string }).text)
+
+	if (!lastUserMessage) {
+		return undefined
+	}
+
+	const content = lastUserMessage.content
+
+	// Handle string content directly
+	if (typeof content === "string") {
+		return content
+	}
+
+	// Handle array content - extract text parts
+	return content
+		.filter((part) => part.type === "text")
+		.map((part) => (part as { type: "text"; text: string }).text)
 		.join(" ")
-	return memories
 }
 
 export const filterOutSupermemories = (content: string) => {
```
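The `getLastUserMessage` rewrite handles the same V2/V3 content union as `extractQueryText`: content may arrive as a plain string or as an array of typed parts, and only text parts contribute. A quick sketch of the two shapes, with simplified part types standing in for the real ones:

```ts
type TextPart = { type: "text"; text: string }
type FilePart = { type: "file"; data: string }
type Content = string | Array<TextPart | FilePart>

// Same extraction rule as the patched getLastUserMessage, on simplified types.
const textOf = (content: Content): string =>
	typeof content === "string"
		? content
		: content
				.filter((part): part is TextPart => part.type === "text")
				.map((part) => part.text)
				.join(" ")

console.log(textOf("plain string message")) // "plain string message"
console.log(
	textOf([
		{ type: "text", text: "hello" },
		{ type: "file", data: "..." },
		{ type: "text", text: "world" },
	]),
) // "hello world"
```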