aboutsummaryrefslogtreecommitdiff
path: root/packages/openai-sdk-python/src/supermemory_openai/middleware.py
diff options
context:
space:
mode:
authornexxeln <[email protected]>2025-11-11 02:16:34 +0000
committernexxeln <[email protected]>2025-11-11 02:16:34 +0000
commitea9bf13d314a605f3b80c7e8ce7e3141db5438df (patch)
treef4670a996de9097ede0fd40ae33277f70c915e75 /packages/openai-sdk-python/src/supermemory_openai/middleware.py
parentchore: update readme with selfhost link (#573) (diff)
downloadsupermemory-ea9bf13d314a605f3b80c7e8ce7e3141db5438df.tar.xz
supermemory-ea9bf13d314a605f3b80c7e8ce7e3141db5438df.zip
add openai middleware functionality for python sdk (#546)openai-middleware-python
add openai middleware functionality fix critical type errors and linting issues update readme with middleware documentation
Diffstat (limited to 'packages/openai-sdk-python/src/supermemory_openai/middleware.py')
-rw-r--r--packages/openai-sdk-python/src/supermemory_openai/middleware.py640
1 files changed, 640 insertions, 0 deletions
diff --git a/packages/openai-sdk-python/src/supermemory_openai/middleware.py b/packages/openai-sdk-python/src/supermemory_openai/middleware.py
new file mode 100644
index 00000000..a21492ae
--- /dev/null
+++ b/packages/openai-sdk-python/src/supermemory_openai/middleware.py
@@ -0,0 +1,640 @@
+"""Supermemory middleware for OpenAI clients."""
+
+from dataclasses import dataclass
+from typing import Optional, Union, Any, Literal, cast
+import asyncio
+import os
+
+from openai import OpenAI, AsyncOpenAI
+from openai.types.chat import (
+ ChatCompletionMessageParam,
+ ChatCompletionSystemMessageParam,
+)
+import supermemory
+
+from .utils import (
+ Logger,
+ create_logger,
+ get_last_user_message,
+ get_conversation_content,
+ convert_profile_to_markdown,
+)
+from .exceptions import (
+ SupermemoryConfigurationError,
+ SupermemoryAPIError,
+ SupermemoryMemoryOperationError,
+ SupermemoryNetworkError,
+)
+
+
+@dataclass
+class OpenAIMiddlewareOptions:
+ """Configuration options for OpenAI middleware."""
+
+ conversation_id: Optional[str] = None
+ verbose: bool = False
+ mode: Literal["profile", "query", "full"] = "profile"
+ add_memory: Literal["always", "never"] = "never"
+
+
+class SupermemoryProfileSearch:
+ """Type for Supermemory profile search response."""
+
+ def __init__(self, data: dict[str, Any]):
+ self.profile: dict[str, Any] = data.get("profile", {})
+ self.search_results: dict[str, Any] = data.get("searchResults", {})
+
+
+async def supermemory_profile_search(
+ container_tag: str,
+ query_text: str,
+ api_key: str,
+) -> SupermemoryProfileSearch:
+ """Search for memories using the SuperMemory profile API."""
+ payload = {
+ "containerTag": container_tag,
+ }
+ if query_text:
+ payload["q"] = query_text
+
+ try:
+ import aiohttp
+
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ "https://api.supermemory.ai/v4/profile",
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {api_key}",
+ },
+ json=payload,
+ ) as response:
+ if not response.ok:
+ error_text = await response.text()
+ raise SupermemoryAPIError(
+ "Supermemory profile search failed",
+ status_code=response.status,
+ response_text=error_text
+ )
+
+ data = await response.json()
+ return SupermemoryProfileSearch(data)
+
+ except ImportError:
+ # Fallback to requests if aiohttp not available
+ import requests
+
+ response = requests.post(
+ "https://api.supermemory.ai/v4/profile",
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {api_key}",
+ },
+ json=payload,
+ )
+
+ if not response.ok:
+ raise SupermemoryAPIError(
+ "Supermemory profile search failed",
+ status_code=response.status_code,
+ response_text=response.text
+ )
+
+ return SupermemoryProfileSearch(response.json())
+
+
+async def add_system_prompt(
+ messages: list[ChatCompletionMessageParam],
+ container_tag: str,
+ logger: Logger,
+ mode: Literal["profile", "query", "full"],
+ api_key: str,
+) -> list[ChatCompletionMessageParam]:
+ """Add memory-enhanced system prompts to chat completion messages."""
+ system_prompt_exists = any(msg.get("role") == "system" for msg in messages)
+
+ query_text = get_last_user_message(messages) if mode != "profile" else ""
+
+ memories_response = await supermemory_profile_search(
+ container_tag, query_text, api_key
+ )
+
+ memory_count_static = len(memories_response.profile.get("static", []))
+ memory_count_dynamic = len(memories_response.profile.get("dynamic", []))
+
+ logger.info(
+ "Memory search completed",
+ {
+ "container_tag": container_tag,
+ "memory_count_static": memory_count_static,
+ "memory_count_dynamic": memory_count_dynamic,
+ "query_text": query_text[:100] + ("..." if len(query_text) > 100 else ""),
+ "mode": mode,
+ },
+ )
+
+ profile_data = ""
+ if mode != "query":
+ profile_data = convert_profile_to_markdown(
+ {
+ "profile": {
+ "static": [
+ item.get("memory", "") if isinstance(item, dict) else str(item)
+ for item in memories_response.profile.get("static", [])
+ ],
+ "dynamic": [
+ item.get("memory", "") if isinstance(item, dict) else str(item)
+ for item in memories_response.profile.get("dynamic", [])
+ ],
+ },
+ "searchResults": {
+ "results": [
+ {"memory": item.get("memory", "") if isinstance(item, dict) else str(item)}
+ for item in memories_response.search_results.get("results", [])
+ ],
+ },
+ }
+ )
+
+ search_results_memories = ""
+ if mode != "profile":
+ search_results = memories_response.search_results.get("results", [])
+ if search_results:
+ search_results_memories = (
+ f"Search results for user's recent message: \n"
+ + "\n".join(
+ f"- {result.get('memory', '') if isinstance(result, dict) else str(result)}" for result in search_results
+ )
+ )
+
+ memories = f"{profile_data}\n{search_results_memories}".strip()
+
+ if memories:
+ logger.debug(
+ "Memory content preview",
+ {
+ "content": memories,
+ "full_length": len(memories),
+ },
+ )
+
+ if system_prompt_exists:
+ logger.debug("Added memories to existing system prompt")
+ return [
+ {**msg, "content": f"{msg.get('content', '')} \n {memories}"}
+ if msg.get("role") == "system"
+ else msg
+ for msg in messages
+ ]
+
+ logger.debug("System prompt does not exist, created system prompt with memories")
+ system_message: ChatCompletionSystemMessageParam = {
+ "role": "system",
+ "content": memories,
+ }
+ return [system_message] + messages
+
+
+async def add_memory_tool(
+ client: supermemory.Supermemory,
+ container_tag: str,
+ content: str,
+ custom_id: Optional[str],
+ logger: Logger,
+) -> None:
+ """Add a new memory to the SuperMemory system."""
+ try:
+ add_params = {
+ "content": content,
+ "container_tags": [container_tag],
+ }
+ if custom_id is not None:
+ add_params["custom_id"] = custom_id
+
+ # Handle both sync and async supermemory clients
+ try:
+ response = await client.memories.add(**add_params)
+ except TypeError:
+ # If it's not awaitable, call it synchronously
+ response = client.memories.add(**add_params)
+
+ logger.info(
+ "Memory saved successfully",
+ {
+ "container_tag": container_tag,
+ "custom_id": custom_id,
+ "content_length": len(content),
+ "memory_id": response.id,
+ },
+ )
+ except (OSError, ConnectionError) as network_error:
+ logger.error(
+ "Network error while saving memory",
+ {"error": str(network_error)},
+ )
+ raise SupermemoryNetworkError(
+ "Failed to save memory due to network error",
+ network_error
+ )
+ except Exception as error:
+ logger.error(
+ "Error saving memory",
+ {"error": str(error)},
+ )
+ raise SupermemoryMemoryOperationError(
+ "Failed to save memory",
+ error
+ )
+
+
+class SupermemoryOpenAIWrapper:
+ """Wrapper for OpenAI client with Supermemory middleware."""
+
+ def __init__(
+ self,
+ openai_client: Union[OpenAI, AsyncOpenAI],
+ container_tag: str,
+ options: Optional[OpenAIMiddlewareOptions] = None,
+ ):
+ self._client: Union[OpenAI, AsyncOpenAI] = openai_client
+ self._container_tag: str = container_tag
+ self._options: OpenAIMiddlewareOptions = options or OpenAIMiddlewareOptions()
+ self._logger: Logger = create_logger(self._options.verbose)
+
+ # Track background tasks to ensure they complete
+ self._background_tasks: set[asyncio.Task] = set()
+
+ if not hasattr(supermemory, "Supermemory"):
+ raise SupermemoryConfigurationError(
+ "supermemory package is required but not found",
+ ImportError("supermemory package not installed")
+ )
+
+ api_key = self._get_api_key()
+ try:
+ self._supermemory_client: supermemory.Supermemory = supermemory.Supermemory(api_key=api_key)
+ except Exception as e:
+ raise SupermemoryConfigurationError(
+ f"Failed to initialize Supermemory client: {e}",
+ e
+ )
+
+ # Wrap the chat completions create method
+ self._wrap_chat_completions()
+
+ def _get_api_key(self) -> str:
+ """Get Supermemory API key from environment."""
+ import os
+
+ api_key = os.getenv("SUPERMEMORY_API_KEY")
+ if not api_key:
+ raise SupermemoryConfigurationError(
+ "SUPERMEMORY_API_KEY environment variable is required but not set"
+ )
+ return api_key
+
+ def _wrap_chat_completions(self) -> None:
+ """Wrap the chat completions create method with memory injection."""
+ original_create = self._client.chat.completions.create
+
+ if asyncio.iscoroutinefunction(original_create):
+
+ async def create_with_memory(
+ **kwargs: Any,
+ ) -> Any:
+ return await self._create_with_memory_async(original_create, **kwargs)
+ else:
+
+ def create_with_memory(
+ **kwargs: Any,
+ ) -> Any:
+ return self._create_with_memory_sync(original_create, **kwargs)
+
+ # Replace the create method with our wrapper
+ setattr(self._client.chat.completions, "create", create_with_memory)
+
+ async def _create_with_memory_async(
+ self,
+ original_create: Any,
+ **kwargs: Any,
+ ) -> Any:
+ """Async version of create with memory injection."""
+ messages = kwargs.get("messages", [])
+
+ if self._options.add_memory == "always":
+ user_message = get_last_user_message(messages)
+ if user_message and user_message.strip():
+ content = (
+ get_conversation_content(messages)
+ if self._options.conversation_id
+ else user_message
+ )
+ custom_id = (
+ f"conversation:{self._options.conversation_id}"
+ if self._options.conversation_id
+ else None
+ )
+
+ # Create background task for memory storage
+ task = asyncio.create_task(
+ add_memory_tool(
+ self._supermemory_client,
+ self._container_tag,
+ content,
+ custom_id,
+ self._logger,
+ )
+ )
+
+ # Track the task and set up cleanup
+ self._background_tasks.add(task)
+ task.add_done_callback(self._background_tasks.discard)
+
+ # Log any exceptions but don't fail the main request
+ def handle_task_exception(task_obj):
+ try:
+ if task_obj.exception() is not None:
+ exception = task_obj.exception()
+ if isinstance(exception, (SupermemoryNetworkError, SupermemoryAPIError)):
+ self._logger.warn(
+ "Background memory storage failed",
+ {"error": str(exception), "type": type(exception).__name__}
+ )
+ else:
+ self._logger.error(
+ "Unexpected error in background memory storage",
+ {"error": str(exception), "type": type(exception).__name__}
+ )
+ except asyncio.CancelledError:
+ self._logger.debug("Memory storage task was cancelled")
+
+ task.add_done_callback(handle_task_exception)
+
+ if self._options.mode != "profile":
+ user_message = get_last_user_message(messages)
+ if not user_message:
+ self._logger.debug("No user message found, skipping memory search")
+ return await original_create(**kwargs)
+
+ self._logger.info(
+ "Starting memory search",
+ {
+ "container_tag": self._container_tag,
+ "conversation_id": self._options.conversation_id,
+ "mode": self._options.mode,
+ },
+ )
+
+ enhanced_messages = await add_system_prompt(
+ messages,
+ self._container_tag,
+ self._logger,
+ self._options.mode,
+ self._get_api_key(),
+ )
+
+ kwargs["messages"] = enhanced_messages
+ return await original_create(**kwargs)
+
+ def _create_with_memory_sync(
+ self,
+ original_create: Any,
+ **kwargs: Any,
+ ) -> Any:
+ """Sync version of create with memory injection."""
+ # For sync clients, we implement a simplified version without background tasks
+ messages = kwargs.get("messages", [])
+
+ # Handle memory addition synchronously if needed
+ if self._options.add_memory == "always":
+ user_message = get_last_user_message(messages)
+ if user_message and user_message.strip():
+ content = (
+ get_conversation_content(messages)
+ if self._options.conversation_id
+ else user_message
+ )
+ custom_id = (
+ f"conversation:{self._options.conversation_id}"
+ if self._options.conversation_id
+ else None
+ )
+
+ # Use asyncio.run() for the memory addition
+ try:
+ asyncio.run(
+ add_memory_tool(
+ self._supermemory_client,
+ self._container_tag,
+ content,
+ custom_id,
+ self._logger,
+ )
+ )
+ except RuntimeError as e:
+ if "cannot be called from a running event loop" in str(e):
+ # We're in an async context, log warning and skip memory saving
+ self._logger.warn(
+ "Cannot save memory in sync client from async context",
+ {"error": str(e)}
+ )
+ else:
+ raise
+ except SupermemoryNetworkError as e:
+ # Network errors are expected, log as warning
+ self._logger.warn("Network error saving memory", {"error": str(e)})
+ except (SupermemoryAPIError, SupermemoryMemoryOperationError) as e:
+ # API/memory errors are concerning, log as error
+ self._logger.error("Failed to save memory", {"error": str(e)})
+ except Exception as e:
+ # Unexpected errors should be investigated
+ self._logger.error(
+ "Unexpected error saving memory",
+ {"error": str(e), "type": type(e).__name__}
+ )
+
+ # Handle memory search and injection
+ if self._options.mode != "profile":
+ user_message = get_last_user_message(messages)
+ if not user_message:
+ self._logger.debug("No user message found, skipping memory search")
+ return original_create(**kwargs)
+
+ self._logger.info("Starting memory search", {
+ "container_tag": self._container_tag,
+ "conversation_id": self._options.conversation_id,
+ "mode": self._options.mode,
+ })
+
+ # Use asyncio.run() for memory search and injection
+ try:
+ enhanced_messages = asyncio.run(
+ add_system_prompt(
+ messages,
+ self._container_tag,
+ self._logger,
+ self._options.mode,
+ self._get_api_key(),
+ )
+ )
+ except RuntimeError as e:
+ if "cannot be called from a running event loop" in str(e):
+ # We're in an async context, run in a separate thread
+ import concurrent.futures
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(
+ asyncio.run,
+ add_system_prompt(
+ messages,
+ self._container_tag,
+ self._logger,
+ self._options.mode,
+ self._get_api_key(),
+ )
+ )
+ enhanced_messages = future.result()
+ else:
+ raise
+
+ kwargs["messages"] = enhanced_messages
+ return original_create(**kwargs)
+
+ async def wait_for_background_tasks(self, timeout: Optional[float] = 10.0) -> None:
+ """
+ Wait for all background memory storage tasks to complete.
+
+ Args:
+ timeout: Maximum time to wait in seconds. None for no timeout.
+
+ Raises:
+ asyncio.TimeoutError: If tasks don't complete within timeout
+ """
+ if not self._background_tasks:
+ return
+
+ self._logger.debug(f"Waiting for {len(self._background_tasks)} background tasks to complete")
+
+ try:
+ if timeout is not None:
+ await asyncio.wait_for(
+ asyncio.gather(*self._background_tasks, return_exceptions=True),
+ timeout=timeout
+ )
+ else:
+ await asyncio.gather(*self._background_tasks, return_exceptions=True)
+
+ self._logger.debug("All background tasks completed")
+ except asyncio.TimeoutError:
+ self._logger.warn(f"Background tasks did not complete within {timeout}s timeout")
+ # Cancel remaining tasks
+ for task in self._background_tasks:
+ if not task.done():
+ task.cancel()
+ raise
+
+ def cancel_background_tasks(self) -> None:
+ """Cancel all pending background tasks."""
+ cancelled_count = 0
+ for task in self._background_tasks:
+ if not task.done():
+ task.cancel()
+ cancelled_count += 1
+
+ if cancelled_count > 0:
+ self._logger.debug(f"Cancelled {cancelled_count} pending background tasks")
+
+ async def __aenter__(self):
+ """Async context manager entry."""
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ """Async context manager exit - wait for background tasks."""
+ try:
+ await self.wait_for_background_tasks(timeout=5.0)
+ except asyncio.TimeoutError:
+ self._logger.warn("Some background memory tasks did not complete on exit")
+
+ def __enter__(self):
+ """Sync context manager entry."""
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Sync context manager exit - attempt to wait for background tasks."""
+ if self._background_tasks:
+ try:
+ # Try to wait for background tasks in sync context
+ asyncio.run(self.wait_for_background_tasks(timeout=5.0))
+ except RuntimeError as e:
+ if "cannot be called from a running event loop" in str(e):
+ # In async context, just cancel the tasks
+ self._logger.warn(
+ "Cannot wait for background tasks in sync context from async environment. "
+ "Use async context manager or call wait_for_background_tasks() manually."
+ )
+ self.cancel_background_tasks()
+ else:
+ raise
+ except asyncio.TimeoutError:
+ self._logger.warn("Some background memory tasks did not complete on exit")
+ self.cancel_background_tasks()
+
+ def __getattr__(self, name: str) -> Any:
+ """Delegate all other attributes to the wrapped client."""
+ return getattr(self._client, name)
+
+
+def with_supermemory(
+ openai_client: Union[OpenAI, AsyncOpenAI],
+ container_tag: str,
+ options: Optional[OpenAIMiddlewareOptions] = None,
+) -> Union[OpenAI, AsyncOpenAI]:
+ """
+ Wraps an OpenAI client with SuperMemory middleware to automatically inject relevant memories
+ into the system prompt based on the user's message content.
+
+ This middleware searches the supermemory API for relevant memories using the container tag
+ and user message, then either appends memories to an existing system prompt or creates
+ a new system prompt with the memories.
+
+ Args:
+ openai_client: The OpenAI client to wrap with SuperMemory middleware
+ container_tag: The container tag/identifier for memory search (e.g., user ID, project ID)
+ options: Optional configuration options for the middleware
+
+ Returns:
+ An OpenAI client with SuperMemory middleware injected
+
+ Example:
+ ```python
+ from supermemory_openai import with_supermemory, OpenAIMiddlewareOptions
+ from openai import OpenAI
+
+ # Create OpenAI client with supermemory middleware
+ openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+ openai_with_supermemory = with_supermemory(
+ openai,
+ "user-123",
+ OpenAIMiddlewareOptions(
+ conversation_id="conversation-456",
+ mode="full",
+ add_memory="always"
+ )
+ )
+
+ # Use normally - memories will be automatically injected
+ response = await openai_with_supermemory.chat.completions.create(
+ model="gpt-4",
+ messages=[
+ {"role": "user", "content": "What's my favorite programming language?"}
+ ]
+ )
+ ```
+
+ Raises:
+ ValueError: When SUPERMEMORY_API_KEY environment variable is not set
+ Exception: When supermemory API request fails
+ """
+ wrapper = SupermemoryOpenAIWrapper(openai_client, container_tag, options)
+ # Return the wrapper, which delegates all attributes to the original client
+ return cast(Union[OpenAI, AsyncOpenAI], wrapper)