"""Supermemory Pipecat service integration.
This module provides a memory service that integrates with Supermemory to store
and retrieve conversational memories, enhancing LLM context with relevant
historical information.
"""
import asyncio
import json
import os
import re
from typing import Any, Dict, List, Literal, Optional
from loguru import logger
from pydantic import BaseModel, Field
from pipecat.frames.frames import Frame, InputAudioRawFrame, LLMContextFrame, LLMMessagesFrame
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pydantic import BaseModel, Field
from .exceptions import ConfigurationError, MemoryRetrievalError
from .utils import deduplicate_memories, format_memories_to_text, get_last_user_message
try:
import supermemory
except ImportError:
supermemory = None # type: ignore
# XML tags for memory injection (replacement instead of accumulation)
MEMORY_TAG_START = ""
MEMORY_TAG_END = ""
MEMORY_TAG_PATTERN = re.compile(r".*?", re.DOTALL)
class SupermemoryPipecatService(FrameProcessor):
"""Memory service that integrates Supermemory with Pipecat pipelines.
This service intercepts message frames in the pipeline, retrieves relevant
memories from Supermemory, and enhances the context before passing downstream.
Example:
```python
from supermemory_pipecat import SupermemoryPipecatService
memory = SupermemoryPipecatService(
api_key=os.getenv("SUPERMEMORY_API_KEY"),
user_id="user-123",
)
```
"""
class InputParams(BaseModel):
"""Configuration parameters for memory retrieval and injection.
Attributes:
search_limit: Maximum number of memories to retrieve per query.
search_threshold: Minimum similarity threshold (0.0-1.0).
system_prompt: Prefix text for memory context.
mode: Memory retrieval mode - "profile", "query", or "full".
inject_mode: How to inject memories - "auto", "system", or "user".
"""
search_limit: int = Field(default=10, ge=1)
search_threshold: float = Field(default=0.1, ge=0.0, le=1.0)
system_prompt: str = Field(default="Based on previous conversations, I recall:\n\n")
mode: Literal["profile", "query", "full"] = Field(default="full")
inject_mode: Literal["auto", "system", "user"] = Field(default="auto")
def __init__(
self,
*,
api_key: Optional[str] = None,
user_id: str,
session_id: Optional[str] = None,
params: Optional[InputParams] = None,
base_url: Optional[str] = None,
):
"""Initialize the Supermemory Pipecat service.
Args:
api_key: Supermemory API key. Falls back to SUPERMEMORY_API_KEY env var.
user_id: The user ID - used as container_tag for memory scoping.
session_id: Session/conversation ID for grouping memories.
params: Configuration parameters for memory retrieval.
base_url: Optional custom base URL for Supermemory API.
Raises:
ConfigurationError: If API key is missing or user_id not provided.
"""
super().__init__()
self.api_key = api_key or os.getenv("SUPERMEMORY_API_KEY")
if not self.api_key:
raise ConfigurationError(
"API key is required. Provide api_key parameter or set SUPERMEMORY_API_KEY environment variable."
)
if not user_id:
raise ConfigurationError("user_id is required")
self.user_id = user_id
self.container_tag = user_id
self.session_id = session_id
self.params = params or SupermemoryPipecatService.InputParams()
self._supermemory_client = None
if supermemory is not None:
try:
self._supermemory_client = supermemory.AsyncSupermemory(
api_key=self.api_key,
base_url=base_url,
)
except Exception as e:
logger.warning(f"Failed to initialize Supermemory client: {e}")
self._messages_sent_count: int = 0
self._last_query: Optional[str] = None
self._audio_frames_detected: bool = False
async def _retrieve_memories(self, query: str) -> Dict[str, Any]:
"""Retrieve relevant memories from Supermemory.
Args:
query: The search query for memory retrieval.
Returns:
Dictionary containing profile (static/dynamic) and search results.
Raises:
MemoryRetrievalError: If retrieval fails.
"""
if self._supermemory_client is None:
raise MemoryRetrievalError(
"Supermemory client not initialized. Install with: pip install supermemory"
)
try:
kwargs: Dict[str, Any] = {"container_tag": self.container_tag}
if self.params.mode != "profile" and query:
kwargs["q"] = query
kwargs["threshold"] = self.params.search_threshold
kwargs["extra_body"] = {"limit": self.params.search_limit}
response = await self._supermemory_client.profile(**kwargs)
search_results = []
if response.search_results and response.search_results.results:
search_results = response.search_results.results
return {
"profile": {
"static": response.profile.static,
"dynamic": response.profile.dynamic,
},
"search_results": search_results,
}
except Exception as e:
logger.error(f"Error retrieving memories: {e}")
raise MemoryRetrievalError("Failed to retrieve memories", e)
async def _store_messages(self, messages: List[Dict[str, Any]]) -> None:
"""Store messages in Supermemory (non-blocking, fire-and-forget)."""
if self._supermemory_client is None or not messages:
return
try:
add_params: Dict[str, Any] = {
"content": json.dumps(messages),
"container_tags": [self.container_tag],
"metadata": {"platform": "pipecat"},
}
if self.session_id:
add_params["custom_id"] = self.session_id
await self._supermemory_client.memories.add(**add_params)
except Exception as e:
logger.error(f"Error storing messages: {e}")
def _enhance_context_with_memories(
self,
context: LLMContext,
query: str,
memories_data: Dict[str, Any],
) -> None:
"""Enhance LLM context with retrieved memories.
Uses XML tags ... to wrap memories,
allowing replacement on each turn instead of accumulation.
Args:
context: The LLM context to enhance.
query: The query used for retrieval.
memories_data: Memory data from Supermemory API.
"""
if self._last_query == query:
return
self._last_query = query
profile = memories_data["profile"]
deduplicated = deduplicate_memories(
static=profile["static"],
dynamic=profile["dynamic"],
search_results=memories_data["search_results"],
)
total_memories = (
len(deduplicated["static"])
+ len(deduplicated["dynamic"])
+ len(deduplicated["search_results"])
)
if total_memories == 0:
return
include_profile = self.params.mode in ("profile", "full")
include_search = self.params.mode in ("query", "full")
memory_text = format_memories_to_text(
deduplicated,
system_prompt=self.params.system_prompt,
include_static=include_profile,
include_dynamic=include_profile,
include_search=include_search,
)
if not memory_text:
return
tagged_memory = f"{MEMORY_TAG_START}\n{memory_text}\n{MEMORY_TAG_END}"
inject_to_system = self.params.inject_mode == "system" or (
self.params.inject_mode == "auto" and self._audio_frames_detected
)
messages = context.get_messages()
if inject_to_system:
system_idx = None
for i, msg in enumerate(messages):
if msg.get("role") == "system":
system_idx = i
break
if system_idx is not None:
existing_content = messages[system_idx].get("content", "")
if MEMORY_TAG_PATTERN.search(existing_content):
messages[system_idx]["content"] = MEMORY_TAG_PATTERN.sub(
tagged_memory, existing_content
)
else:
messages[system_idx]["content"] = f"{existing_content}\n\n{tagged_memory}"
else:
messages.insert(0, {"role": "system", "content": tagged_memory})
else:
# Remove previous memory message if exists
for i in range(len(messages) - 1, -1, -1):
msg = messages[i]
if msg.get("role") == "user" and MEMORY_TAG_START in msg.get("content", ""):
messages.pop(i)
break
context.add_message({"role": "user", "content": tagged_memory})
async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
"""Process frames, intercept context frames for memory integration."""
await super().process_frame(frame, direction)
# Auto-detect speech-to-speech mode via audio frames
if isinstance(frame, InputAudioRawFrame):
if not self._audio_frames_detected:
self._audio_frames_detected = True
await self.push_frame(frame, direction)
return
context = None
messages = None
if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
context = frame.context
elif isinstance(frame, LLMMessagesFrame):
messages = frame.messages
context = LLMContext(messages)
if context:
try:
context_messages = context.get_messages()
latest_user_message = get_last_user_message(context_messages)
if latest_user_message:
try:
memories_data = await self._retrieve_memories(latest_user_message)
self._enhance_context_with_memories(
context, latest_user_message, memories_data
)
except MemoryRetrievalError as e:
logger.warning(f"Memory retrieval failed: {e}")
# Store unsent messages (user and assistant only)
storable_messages = [
msg for msg in context_messages if msg["role"] in ("user", "assistant")
]
unsent_messages = storable_messages[self._messages_sent_count :]
if unsent_messages:
asyncio.create_task(self._store_messages(unsent_messages))
self._messages_sent_count = len(storable_messages)
if messages is not None:
await self.push_frame(LLMMessagesFrame(context.get_messages()))
else:
await self.push_frame(frame)
except Exception as e:
logger.error(f"Error processing frame: {e}")
await self.push_frame(frame)
else:
await self.push_frame(frame, direction)
def reset_memory_tracking(self) -> None:
"""Reset memory tracking state for a new conversation."""
self._messages_sent_count = 0
self._last_query = None
self._audio_frames_detected = False