aboutsummaryrefslogtreecommitdiff
path: root/packages/tools/src/vercel/middleware.ts
diff options
context:
space:
mode:
Diffstat (limited to 'packages/tools/src/vercel/middleware.ts')
-rw-r--r--packages/tools/src/vercel/middleware.ts233
1 files changed, 117 insertions, 116 deletions
diff --git a/packages/tools/src/vercel/middleware.ts b/packages/tools/src/vercel/middleware.ts
index 260718b2..a2dd77ee 100644
--- a/packages/tools/src/vercel/middleware.ts
+++ b/packages/tools/src/vercel/middleware.ts
@@ -1,8 +1,3 @@
-import type {
- LanguageModelV2CallOptions,
- LanguageModelV2Middleware,
- LanguageModelV2StreamPart,
-} from "@ai-sdk/provider"
import Supermemory from "supermemory"
import {
addConversation,
@@ -10,13 +5,15 @@ import {
} from "../conversations-client"
import { createLogger, type Logger } from "./logger"
import {
+ type LanguageModelCallOptions,
+ type LanguageModelStreamPart,
type OutputContentItem,
getLastUserMessage,
filterOutSupermemories,
} from "./util"
import { addSystemPrompt, normalizeBaseUrl } from "./memory-prompt"
-const getConversationContent = (params: LanguageModelV2CallOptions) => {
+export const getConversationContent = (params: LanguageModelCallOptions) => {
return params.prompt
.filter((msg) => msg.role !== "system" && msg.role !== "tool")
.map((msg) => {
@@ -35,8 +32,8 @@ const getConversationContent = (params: LanguageModelV2CallOptions) => {
.join("\n\n")
}
-const convertToConversationMessages = (
- params: LanguageModelV2CallOptions,
+export const convertToConversationMessages = (
+ params: LanguageModelCallOptions,
assistantResponseText: string,
): ConversationMessage[] => {
const messages: ConversationMessage[] = []
@@ -95,12 +92,12 @@ const convertToConversationMessages = (
return messages
}
-const addMemoryTool = async (
+export const saveMemoryAfterResponse = async (
client: Supermemory,
containerTag: string,
conversationId: string | undefined,
assistantResponseText: string,
- params: LanguageModelV2CallOptions,
+ params: LanguageModelCallOptions,
logger: Logger,
apiKey: string,
baseUrl: string,
@@ -156,15 +153,40 @@ const addMemoryTool = async (
}
}
-export const createSupermemoryMiddleware = (
- containerTag: string,
- apiKey: string,
- conversationId?: string,
- verbose = false,
- mode: "profile" | "query" | "full" = "profile",
- addMemory: "always" | "never" = "never",
- baseUrl?: string,
-): LanguageModelV2Middleware => {
+export interface SupermemoryMiddlewareOptions {
+ containerTag: string
+ apiKey: string
+ conversationId?: string
+ verbose?: boolean
+ mode?: "profile" | "query" | "full"
+ addMemory?: "always" | "never"
+ baseUrl?: string
+}
+
+export interface SupermemoryMiddlewareContext {
+ client: Supermemory
+ logger: Logger
+ containerTag: string
+ conversationId?: string
+ mode: "profile" | "query" | "full"
+ addMemory: "always" | "never"
+ normalizedBaseUrl: string
+ apiKey: string
+}
+
+export const createSupermemoryContext = (
+ options: SupermemoryMiddlewareOptions,
+): SupermemoryMiddlewareContext => {
+ const {
+ containerTag,
+ apiKey,
+ conversationId,
+ verbose = false,
+ mode = "profile",
+ addMemory = "never",
+ baseUrl,
+ } = options
+
const logger = createLogger(verbose)
const normalizedBaseUrl = normalizeBaseUrl(baseUrl)
@@ -176,113 +198,92 @@ export const createSupermemoryMiddleware = (
})
return {
- transformParams: async ({ params }) => {
- const userMessage = getLastUserMessage(params)
+ client,
+ logger,
+ containerTag,
+ conversationId,
+ mode,
+ addMemory,
+ normalizedBaseUrl,
+ apiKey,
+ }
+}
- if (mode !== "profile") {
- if (!userMessage) {
- logger.debug("No user message found, skipping memory search")
- return params
- }
- }
+export const transformParamsWithMemory = async (
+ params: LanguageModelCallOptions,
+ ctx: SupermemoryMiddlewareContext,
+): Promise<LanguageModelCallOptions> => {
+ const userMessage = getLastUserMessage(params)
- logger.info("Starting memory search", {
- containerTag,
- conversationId,
- mode,
- })
+ if (ctx.mode !== "profile") {
+ if (!userMessage) {
+ ctx.logger.debug("No user message found, skipping memory search")
+ return params
+ }
+ }
- const transformedParams = await addSystemPrompt(
- params,
- containerTag,
- logger,
- mode,
- normalizedBaseUrl,
- )
- return transformedParams
- },
- wrapGenerate: async ({ doGenerate, params }) => {
- const userMessage = getLastUserMessage(params)
+ ctx.logger.info("Starting memory search", {
+ containerTag: ctx.containerTag,
+ conversationId: ctx.conversationId,
+ mode: ctx.mode,
+ })
+
+ const transformedParams = await addSystemPrompt(
+ params,
+ ctx.containerTag,
+ ctx.logger,
+ ctx.mode,
+ ctx.normalizedBaseUrl,
+ )
+ return transformedParams
+}
- try {
- const result = await doGenerate()
- const assistantResponse = result.content
- const assistantResponseText = assistantResponse
- .map((content) => (content.type === "text" ? content.text : ""))
- .join("")
+export const extractAssistantResponseText = (content: unknown[]): string => {
+ return (content as Array<{ type: string; text?: string }>)
+ .map((item) => (item.type === "text" ? item.text || "" : ""))
+ .join("")
+}
- if (addMemory === "always" && userMessage && userMessage.trim()) {
- addMemoryTool(
- client,
- containerTag,
- conversationId,
- assistantResponseText,
- params,
- logger,
- apiKey,
- normalizedBaseUrl,
- )
- }
+export const createStreamTransform = (
+ ctx: SupermemoryMiddlewareContext,
+ params: LanguageModelCallOptions,
+): {
+ transform: TransformStream<LanguageModelStreamPart, LanguageModelStreamPart>
+ getGeneratedText: () => string
+} => {
+ let generatedText = ""
- return result
- } catch (error) {
- logger.error("Error generating response", {
- error: error instanceof Error ? error.message : "Unknown error",
- })
- throw error
+ const transform = new TransformStream<
+ LanguageModelStreamPart,
+ LanguageModelStreamPart
+ >({
+ transform(chunk, controller) {
+ if (chunk.type === "text-delta") {
+ generatedText += chunk.delta
}
+ controller.enqueue(chunk)
},
- wrapStream: async ({ doStream, params }) => {
+ flush: async () => {
const userMessage = getLastUserMessage(params)
- let generatedText = ""
-
- try {
- const { stream, ...rest } = await doStream()
- const transformStream = new TransformStream<
- LanguageModelV2StreamPart,
- LanguageModelV2StreamPart
- >({
- transform(chunk, controller) {
- if (chunk.type === "text-delta") {
- generatedText += chunk.delta
- }
-
- controller.enqueue(chunk)
- },
- flush: async () => {
- const content: OutputContentItem[] = []
- if (generatedText) {
- content.push({
- type: "text",
- text: generatedText,
- })
- }
-
- if (addMemory === "always" && userMessage && userMessage.trim()) {
- addMemoryTool(
- client,
- containerTag,
- conversationId,
- generatedText,
- params,
- logger,
- apiKey,
- normalizedBaseUrl,
- )
- }
- },
- })
-
- return {
- stream: stream.pipeThrough(transformStream),
- ...rest,
- }
- } catch (error) {
- logger.error("Error streaming response", {
- error: error instanceof Error ? error.message : "Unknown error",
- })
- throw error
+ if (ctx.addMemory === "always" && userMessage && userMessage.trim()) {
+ saveMemoryAfterResponse(
+ ctx.client,
+ ctx.containerTag,
+ ctx.conversationId,
+ generatedText,
+ params,
+ ctx.logger,
+ ctx.apiKey,
+ ctx.normalizedBaseUrl,
+ )
}
},
+ })
+
+ return {
+ transform,
+ getGeneratedText: () => generatedText,
}
}
+
+export { createLogger, type Logger, type OutputContentItem }