diff options
Diffstat (limited to 'packages/tools/src/openai/middleware.ts')
| -rw-r--r-- | packages/tools/src/openai/middleware.ts | 181 |
1 files changed, 151 insertions, 30 deletions
diff --git a/packages/tools/src/openai/middleware.ts b/packages/tools/src/openai/middleware.ts index 3a851b31..29b66f70 100644 --- a/packages/tools/src/openai/middleware.ts +++ b/packages/tools/src/openai/middleware.ts @@ -1,13 +1,21 @@ import type OpenAI from "openai" import Supermemory from "supermemory" +import { addConversation } from "../conversations-client" import { createLogger, type Logger } from "../vercel/logger" import { convertProfileToMarkdown } from "../vercel/util" +const normalizeBaseUrl = (url?: string): string => { + const defaultUrl = "https://api.supermemory.ai" + if (!url) return defaultUrl + return url.endsWith("/") ? url.slice(0, -1) : url +} + export interface OpenAIMiddlewareOptions { conversationId?: string verbose?: boolean mode?: "profile" | "query" | "full" addMemory?: "always" | "never" + baseUrl?: string } interface SupermemoryProfileSearch { @@ -78,6 +86,7 @@ const getLastUserMessage = ( const supermemoryProfileSearch = async ( containerTag: string, queryText: string, + baseUrl: string, ): Promise<SupermemoryProfileSearch> => { const payload = queryText ? JSON.stringify({ @@ -89,7 +98,7 @@ const supermemoryProfileSearch = async ( }) try { - const response = await fetch("https://api.supermemory.ai/v4/profile", { + const response = await fetch(`${baseUrl}/v4/profile`, { method: "POST", headers: { "Content-Type": "application/json", @@ -147,18 +156,61 @@ const addSystemPrompt = async ( containerTag: string, logger: Logger, mode: "profile" | "query" | "full", + baseUrl: string, ) => { const systemPromptExists = messages.some((msg) => msg.role === "system") const queryText = mode !== "profile" ? getLastUserMessage(messages) : "" - const memories = await searchAndFormatMemories( + const memoriesResponse = await supermemoryProfileSearch( + containerTag, queryText, + baseUrl, + ) + + const memoryCountStatic = memoriesResponse.profile.static?.length || 0 + const memoryCountDynamic = memoriesResponse.profile.dynamic?.length || 0 + + logger.info("Memory search completed for chat API", { containerTag, - logger, + memoryCountStatic, + memoryCountDynamic, + queryText: + queryText.substring(0, 100) + (queryText.length > 100 ? "..." : ""), mode, - "chat", - ) + }) + + const profileData = + mode !== "query" + ? convertProfileToMarkdown({ + profile: { + static: memoriesResponse.profile.static?.map((item) => item.memory), + dynamic: memoriesResponse.profile.dynamic?.map( + (item) => item.memory, + ), + }, + searchResults: { + results: memoriesResponse.searchResults.results.map((item) => ({ + memory: item.memory, + })) as [{ memory: string }], + }, + }) + : "" + const searchResultsMemories = + mode !== "profile" + ? `Search results for user's recent message: \n${memoriesResponse.searchResults.results + .map((result) => `- ${result.memory}`) + .join("\n")}` + : "" + + const memories = `${profileData}\n${searchResultsMemories}`.trim() + + if (memories) { + logger.debug("Memory content preview for chat API", { + content: memories, + fullLength: memories.length, + }) + } if (systemPromptExists) { logger.debug("Added memories to existing system prompt") @@ -215,11 +267,17 @@ const getConversationContent = ( * Saves the provided content as a memory with the specified container tag and * optional custom ID. Logs success or failure information for debugging. * + * If customId starts with "conversation:" and messages are provided, uses the + * /v4/conversations endpoint with structured messages instead of the memories endpoint. + * * @param client - SuperMemory client instance * @param containerTag - The container tag/identifier for the memory - * @param content - The content to save as a memory - * @param customId - Optional custom ID for the memory (e.g., conversation ID) + * @param content - The content to save as a memory (used for fallback) + * @param customId - Optional custom ID for the memory (e.g., conversation:456) * @param logger - Logger instance for debugging and info output + * @param messages - Optional OpenAI messages array (for conversation endpoint) + * @param apiKey - API key for direct conversation endpoint calls + * @param baseUrl - Base URL for API calls * @returns Promise that resolves when memory is saved (or fails silently) * * @example @@ -227,9 +285,12 @@ const getConversationContent = ( * await addMemoryTool( * supermemoryClient, * "user-123", - * "User prefers React with TypeScript", - * "conversation-456", - * logger + * "User: Hello\n\nAssistant: Hi!", + * "conversation:456", + * logger, + * messages, // OpenAI messages array + * apiKey, + * baseUrl * ) * ``` */ @@ -239,8 +300,51 @@ const addMemoryTool = async ( content: string, customId: string | undefined, logger: Logger, + messages?: OpenAI.Chat.Completions.ChatCompletionMessageParam[], + apiKey?: string, + baseUrl?: string, ): Promise<void> => { try { + if (customId && messages && apiKey) { + const conversationId = customId.replace("conversation:", "") + + // Convert OpenAI messages to conversation format + const conversationMessages = messages.map((msg) => ({ + role: msg.role as "user" | "assistant" | "system" | "tool", + content: + typeof msg.content === "string" + ? msg.content + : Array.isArray(msg.content) + ? msg.content + .filter((c) => c.type === "text") + .map((c) => ({ + type: "text" as const, + text: (c as { type: "text"; text: string }).text, + })) + : "", + ...((msg as any).name && { name: (msg as any).name }), + ...((msg as any).tool_calls && { tool_calls: (msg as any).tool_calls }), + ...((msg as any).tool_call_id && { tool_call_id: (msg as any).tool_call_id }), + })) + + const response = await addConversation({ + conversationId, + messages: conversationMessages, + containerTags: [containerTag], + apiKey, + baseUrl, + }) + + logger.info("Conversation saved successfully via /v4/conversations", { + containerTag, + conversationId, + messageCount: messages.length, + responseId: response.id, + }) + return + } + + // Fallback to old behavior for non-conversation memories const response = await client.memories.add({ content, containerTags: [containerTag], @@ -293,8 +397,10 @@ export function createOpenAIMiddleware( options?: OpenAIMiddlewareOptions, ) { const logger = createLogger(options?.verbose ?? false) + const baseUrl = normalizeBaseUrl(options?.baseUrl) const client = new Supermemory({ apiKey: process.env.SUPERMEMORY_API_KEY, + ...(baseUrl !== "https://api.supermemory.ai" ? { baseURL: baseUrl } : {}), }) const conversationId = options?.conversationId @@ -327,6 +433,7 @@ export function createOpenAIMiddleware( const memoriesResponse = await supermemoryProfileSearch( containerTag, queryText, + baseUrl, ) const memoryCountStatic = memoriesResponse.profile.static?.length || 0 @@ -345,7 +452,9 @@ export function createOpenAIMiddleware( mode !== "query" ? convertProfileToMarkdown({ profile: { - static: memoriesResponse.profile.static?.map((item) => item.memory), + static: memoriesResponse.profile.static?.map( + (item) => item.memory, + ), dynamic: memoriesResponse.profile.dynamic?.map( (item) => item.memory, ), @@ -380,7 +489,9 @@ export function createOpenAIMiddleware( params: Parameters<typeof originalResponsesCreate>[0], ) => { if (!originalResponsesCreate) { - throw new Error("Responses API is not available in this OpenAI client version") + throw new Error( + "Responses API is not available in this OpenAI client version", + ) } const input = typeof params.input === "string" ? params.input : "" @@ -399,24 +510,26 @@ export function createOpenAIMiddleware( const operations: Promise<any>[] = [] if (addMemory === "always" && input?.trim()) { - const content = conversationId - ? `Input: ${input}` - : input + const content = conversationId ? `Input: ${input}` : input const customId = conversationId ? `conversation:${conversationId}` : undefined - operations.push(addMemoryTool(client, containerTag, content, customId, logger)) + operations.push( + addMemoryTool(client, containerTag, content, customId, logger), + ) } const queryText = mode !== "profile" ? input : "" - operations.push(searchAndFormatMemories( - queryText, - containerTag, - logger, - mode, - "responses", - )) + operations.push( + searchAndFormatMemories( + queryText, + containerTag, + logger, + mode, + "responses", + ), + ) const results = await Promise.all(operations) const memories = results[results.length - 1] // Memory search result is always last @@ -462,16 +575,24 @@ export function createOpenAIMiddleware( ? `conversation:${conversationId}` : undefined - operations.push(addMemoryTool(client, containerTag, content, customId, logger)) + operations.push( + addMemoryTool( + client, + containerTag, + content, + customId, + logger, + messages, + process.env.SUPERMEMORY_API_KEY, + baseUrl, + ), + ) } } - operations.push(addSystemPrompt( - messages, - containerTag, - logger, - mode, - )) + operations.push( + addSystemPrompt(messages, containerTag, logger, mode, baseUrl), + ) const results = await Promise.all(operations) const enhancedMessages = results[results.length - 1] // Enhanced messages result is always last |