diff options
Diffstat (limited to 'packages/tools/src/vercel/middleware.ts')
| -rw-r--r-- | packages/tools/src/vercel/middleware.ts | 233 |
1 files changed, 117 insertions, 116 deletions
diff --git a/packages/tools/src/vercel/middleware.ts b/packages/tools/src/vercel/middleware.ts index 260718b2..a2dd77ee 100644 --- a/packages/tools/src/vercel/middleware.ts +++ b/packages/tools/src/vercel/middleware.ts @@ -1,8 +1,3 @@ -import type { - LanguageModelV2CallOptions, - LanguageModelV2Middleware, - LanguageModelV2StreamPart, -} from "@ai-sdk/provider" import Supermemory from "supermemory" import { addConversation, @@ -10,13 +5,15 @@ import { } from "../conversations-client" import { createLogger, type Logger } from "./logger" import { + type LanguageModelCallOptions, + type LanguageModelStreamPart, type OutputContentItem, getLastUserMessage, filterOutSupermemories, } from "./util" import { addSystemPrompt, normalizeBaseUrl } from "./memory-prompt" -const getConversationContent = (params: LanguageModelV2CallOptions) => { +export const getConversationContent = (params: LanguageModelCallOptions) => { return params.prompt .filter((msg) => msg.role !== "system" && msg.role !== "tool") .map((msg) => { @@ -35,8 +32,8 @@ const getConversationContent = (params: LanguageModelV2CallOptions) => { .join("\n\n") } -const convertToConversationMessages = ( - params: LanguageModelV2CallOptions, +export const convertToConversationMessages = ( + params: LanguageModelCallOptions, assistantResponseText: string, ): ConversationMessage[] => { const messages: ConversationMessage[] = [] @@ -95,12 +92,12 @@ const convertToConversationMessages = ( return messages } -const addMemoryTool = async ( +export const saveMemoryAfterResponse = async ( client: Supermemory, containerTag: string, conversationId: string | undefined, assistantResponseText: string, - params: LanguageModelV2CallOptions, + params: LanguageModelCallOptions, logger: Logger, apiKey: string, baseUrl: string, @@ -156,15 +153,40 @@ const addMemoryTool = async ( } } -export const createSupermemoryMiddleware = ( - containerTag: string, - apiKey: string, - conversationId?: string, - verbose = false, - mode: "profile" | "query" | "full" = "profile", - addMemory: "always" | "never" = "never", - baseUrl?: string, -): LanguageModelV2Middleware => { +export interface SupermemoryMiddlewareOptions { + containerTag: string + apiKey: string + conversationId?: string + verbose?: boolean + mode?: "profile" | "query" | "full" + addMemory?: "always" | "never" + baseUrl?: string +} + +export interface SupermemoryMiddlewareContext { + client: Supermemory + logger: Logger + containerTag: string + conversationId?: string + mode: "profile" | "query" | "full" + addMemory: "always" | "never" + normalizedBaseUrl: string + apiKey: string +} + +export const createSupermemoryContext = ( + options: SupermemoryMiddlewareOptions, +): SupermemoryMiddlewareContext => { + const { + containerTag, + apiKey, + conversationId, + verbose = false, + mode = "profile", + addMemory = "never", + baseUrl, + } = options + const logger = createLogger(verbose) const normalizedBaseUrl = normalizeBaseUrl(baseUrl) @@ -176,113 +198,92 @@ export const createSupermemoryMiddleware = ( }) return { - transformParams: async ({ params }) => { - const userMessage = getLastUserMessage(params) + client, + logger, + containerTag, + conversationId, + mode, + addMemory, + normalizedBaseUrl, + apiKey, + } +} - if (mode !== "profile") { - if (!userMessage) { - logger.debug("No user message found, skipping memory search") - return params - } - } +export const transformParamsWithMemory = async ( + params: LanguageModelCallOptions, + ctx: SupermemoryMiddlewareContext, +): Promise<LanguageModelCallOptions> => { + const userMessage = getLastUserMessage(params) - logger.info("Starting memory search", { - containerTag, - conversationId, - mode, - }) + if (ctx.mode !== "profile") { + if (!userMessage) { + ctx.logger.debug("No user message found, skipping memory search") + return params + } + } - const transformedParams = await addSystemPrompt( - params, - containerTag, - logger, - mode, - normalizedBaseUrl, - ) - return transformedParams - }, - wrapGenerate: async ({ doGenerate, params }) => { - const userMessage = getLastUserMessage(params) + ctx.logger.info("Starting memory search", { + containerTag: ctx.containerTag, + conversationId: ctx.conversationId, + mode: ctx.mode, + }) + + const transformedParams = await addSystemPrompt( + params, + ctx.containerTag, + ctx.logger, + ctx.mode, + ctx.normalizedBaseUrl, + ) + return transformedParams +} - try { - const result = await doGenerate() - const assistantResponse = result.content - const assistantResponseText = assistantResponse - .map((content) => (content.type === "text" ? content.text : "")) - .join("") +export const extractAssistantResponseText = (content: unknown[]): string => { + return (content as Array<{ type: string; text?: string }>) + .map((item) => (item.type === "text" ? item.text || "" : "")) + .join("") +} - if (addMemory === "always" && userMessage && userMessage.trim()) { - addMemoryTool( - client, - containerTag, - conversationId, - assistantResponseText, - params, - logger, - apiKey, - normalizedBaseUrl, - ) - } +export const createStreamTransform = ( + ctx: SupermemoryMiddlewareContext, + params: LanguageModelCallOptions, +): { + transform: TransformStream<LanguageModelStreamPart, LanguageModelStreamPart> + getGeneratedText: () => string +} => { + let generatedText = "" - return result - } catch (error) { - logger.error("Error generating response", { - error: error instanceof Error ? error.message : "Unknown error", - }) - throw error + const transform = new TransformStream< + LanguageModelStreamPart, + LanguageModelStreamPart + >({ + transform(chunk, controller) { + if (chunk.type === "text-delta") { + generatedText += chunk.delta } + controller.enqueue(chunk) }, - wrapStream: async ({ doStream, params }) => { + flush: async () => { const userMessage = getLastUserMessage(params) - let generatedText = "" - - try { - const { stream, ...rest } = await doStream() - const transformStream = new TransformStream< - LanguageModelV2StreamPart, - LanguageModelV2StreamPart - >({ - transform(chunk, controller) { - if (chunk.type === "text-delta") { - generatedText += chunk.delta - } - - controller.enqueue(chunk) - }, - flush: async () => { - const content: OutputContentItem[] = [] - if (generatedText) { - content.push({ - type: "text", - text: generatedText, - }) - } - - if (addMemory === "always" && userMessage && userMessage.trim()) { - addMemoryTool( - client, - containerTag, - conversationId, - generatedText, - params, - logger, - apiKey, - normalizedBaseUrl, - ) - } - }, - }) - - return { - stream: stream.pipeThrough(transformStream), - ...rest, - } - } catch (error) { - logger.error("Error streaming response", { - error: error instanceof Error ? error.message : "Unknown error", - }) - throw error + if (ctx.addMemory === "always" && userMessage && userMessage.trim()) { + saveMemoryAfterResponse( + ctx.client, + ctx.containerTag, + ctx.conversationId, + generatedText, + params, + ctx.logger, + ctx.apiKey, + ctx.normalizedBaseUrl, + ) } }, + }) + + return { + transform, + getGeneratedText: () => generatedText, } } + +export { createLogger, type Logger, type OutputContentItem } |