diff options
Diffstat (limited to 'packages/pipecat-sdk-python/src/supermemory_pipecat/service.py')
| -rw-r--r-- | packages/pipecat-sdk-python/src/supermemory_pipecat/service.py | 232 |
1 files changed, 94 insertions, 138 deletions
diff --git a/packages/pipecat-sdk-python/src/supermemory_pipecat/service.py b/packages/pipecat-sdk-python/src/supermemory_pipecat/service.py index 01bc03df..2aef866b 100644 --- a/packages/pipecat-sdk-python/src/supermemory_pipecat/service.py +++ b/packages/pipecat-sdk-python/src/supermemory_pipecat/service.py @@ -8,37 +8,37 @@ historical information. import asyncio import json import os +import re from typing import Any, Dict, List, Literal, Optional from loguru import logger -from pipecat.frames.frames import Frame, LLMContextFrame, LLMMessagesFrame +from pydantic import BaseModel, Field + +from pipecat.frames.frames import Frame, InputAudioRawFrame, LLMContextFrame, LLMMessagesFrame from pipecat.processors.aggregators.llm_context import LLMContext from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pydantic import BaseModel, Field -from .exceptions import ( - ConfigurationError, - MemoryRetrievalError, -) -from .utils import ( - deduplicate_memories, - format_memories_to_text, - get_last_user_message, -) +from .exceptions import ConfigurationError, MemoryRetrievalError +from .utils import deduplicate_memories, format_memories_to_text, get_last_user_message try: import supermemory except ImportError: supermemory = None # type: ignore +# XML tags for memory injection (replacement instead of accumulation) +MEMORY_TAG_START = "<user_memories>" +MEMORY_TAG_END = "</user_memories>" +MEMORY_TAG_PATTERN = re.compile(r"<user_memories>.*?</user_memories>", re.DOTALL) + class SupermemoryPipecatService(FrameProcessor): - """A memory service that integrates Supermemory with Pipecat pipelines. + """Memory service that integrates Supermemory with Pipecat pipelines. This service intercepts message frames in the pipeline, retrieves relevant - memories from Supermemory, enhances the context, and optionally stores - new conversations. + memories from Supermemory, and enhances the context before passing downstream. Example: ```python @@ -48,34 +48,25 @@ class SupermemoryPipecatService(FrameProcessor): api_key=os.getenv("SUPERMEMORY_API_KEY"), user_id="user-123", ) - - pipeline = Pipeline([ - transport.input(), - stt, - user_context, - memory, # Memory service enhances context here - llm, - transport.output(), - ]) ``` """ class InputParams(BaseModel): - """Configuration parameters for Supermemory Pipecat service. + """Configuration parameters for memory retrieval and injection. - Parameters: + Attributes: search_limit: Maximum number of memories to retrieve per query. - search_threshold: Minimum similarity threshold for memory retrieval. - system_prompt: Prefix text for memory context messages. + search_threshold: Minimum similarity threshold (0.0-1.0). + system_prompt: Prefix text for memory context. mode: Memory retrieval mode - "profile", "query", or "full". + inject_mode: How to inject memories - "auto", "system", or "user". """ search_limit: int = Field(default=10, ge=1) search_threshold: float = Field(default=0.1, ge=0.0, le=1.0) - system_prompt: str = Field( - default="Based on previous conversations, I recall:\n\n" - ) + system_prompt: str = Field(default="Based on previous conversations, I recall:\n\n") mode: Literal["profile", "query", "full"] = Field(default="full") + inject_mode: Literal["auto", "system", "user"] = Field(default="auto") def __init__( self, @@ -89,10 +80,10 @@ class SupermemoryPipecatService(FrameProcessor): """Initialize the Supermemory Pipecat service. Args: - api_key: The API key for Supermemory. Falls back to SUPERMEMORY_API_KEY env var. - user_id: The user ID - used as container_tag for memory scoping (REQUIRED). - session_id: Session/conversation ID for grouping memories (optional). - params: Configuration parameters for memory retrieval and storage. + api_key: Supermemory API key. Falls back to SUPERMEMORY_API_KEY env var. + user_id: The user ID - used as container_tag for memory scoping. + session_id: Session/conversation ID for grouping memories. + params: Configuration parameters for memory retrieval. base_url: Optional custom base URL for Supermemory API. Raises: @@ -100,25 +91,20 @@ class SupermemoryPipecatService(FrameProcessor): """ super().__init__() - # Get API key self.api_key = api_key or os.getenv("SUPERMEMORY_API_KEY") if not self.api_key: raise ConfigurationError( "API key is required. Provide api_key parameter or set SUPERMEMORY_API_KEY environment variable." ) - # user_id is required and used directly as container_tag if not user_id: raise ConfigurationError("user_id is required") self.user_id = user_id - self.container_tag = user_id # container_tag = user_id directly - self.session_id = session_id # optional session/conversation ID - - # Configuration + self.container_tag = user_id + self.session_id = session_id self.params = params or SupermemoryPipecatService.InputParams() - # Initialize async Supermemory client self._supermemory_client = None if supermemory is not None: try: @@ -129,25 +115,21 @@ class SupermemoryPipecatService(FrameProcessor): except Exception as e: logger.warning(f"Failed to initialize Supermemory client: {e}") - # Track how many messages we've already sent to memory self._messages_sent_count: int = 0 - - # Track last query to avoid duplicate processing self._last_query: Optional[str] = None - - logger.info( - f"Initialized SupermemoryPipecatService with " - f"user_id={user_id}, session_id={session_id}" - ) + self._audio_frames_detected: bool = False async def _retrieve_memories(self, query: str) -> Dict[str, Any]: """Retrieve relevant memories from Supermemory. Args: - query: The query to search for relevant memories. + query: The search query for memory retrieval. Returns: - Dictionary containing profile and search results. + Dictionary containing profile (static/dynamic) and search results. + + Raises: + MemoryRetrievalError: If retrieval fails. """ if self._supermemory_client is None: raise MemoryRetrievalError( @@ -155,29 +137,20 @@ class SupermemoryPipecatService(FrameProcessor): ) try: - logger.debug(f"Retrieving memories for query: {query[:100]}...") + kwargs: Dict[str, Any] = {"container_tag": self.container_tag} - # Build kwargs for profile request - kwargs: Dict[str, Any] = { - "container_tag": self.container_tag, - } - - # Add query for search modes if self.params.mode != "profile" and query: kwargs["q"] = query kwargs["threshold"] = self.params.search_threshold - # Pass limit via extra_body since SDK doesn't have direct param kwargs["extra_body"] = {"limit": self.params.search_limit} - # Use SDK's profile method response = await self._supermemory_client.profile(**kwargs) - # Extract memory strings from SDK response search_results = [] if response.search_results and response.search_results.results: - search_results = [r["memory"] for r in response.search_results.results] + search_results = response.search_results.results - data: Dict[str, Any] = { + return { "profile": { "static": response.profile.static, "dynamic": response.profile.dynamic, @@ -185,53 +158,28 @@ class SupermemoryPipecatService(FrameProcessor): "search_results": search_results, } - logger.debug( - f"Retrieved memories - static: {len(data['profile']['static'])}, " - f"dynamic: {len(data['profile']['dynamic'])}, " - f"search: {len(data['search_results'])}" - ) - return data - except Exception as e: logger.error(f"Error retrieving memories: {e}") raise MemoryRetrievalError("Failed to retrieve memories", e) async def _store_messages(self, messages: List[Dict[str, Any]]) -> None: - """Store messages in Supermemory. - - Args: - messages: List of message dicts with 'role' and 'content' keys. - """ - if self._supermemory_client is None: - logger.warning( - "Supermemory client not initialized, skipping memory storage" - ) - return - - if not messages: + """Store messages in Supermemory (non-blocking, fire-and-forget).""" + if self._supermemory_client is None or not messages: return try: - # Format messages as JSON array - formatted_content = json.dumps(messages) - - logger.debug(f"Storing {len(messages)} messages to Supermemory") - - # Build storage params add_params: Dict[str, Any] = { - "content": formatted_content, + "content": json.dumps(messages), "container_tags": [self.container_tag], "metadata": {"platform": "pipecat"}, } if self.session_id: - add_params["custom_id"] = f"{self.session_id}" + add_params["custom_id"] = self.session_id - await self._supermemory_client.add(**add_params) - logger.debug(f"Successfully stored {len(messages)} messages in Supermemory") + await self._supermemory_client.memories.add(**add_params) except Exception as e: - # Don't fail the pipeline on storage errors - logger.error(f"Error storing messages in Supermemory: {e}") + logger.error(f"Error storing messages: {e}") def _enhance_context_with_memories( self, @@ -239,29 +187,28 @@ class SupermemoryPipecatService(FrameProcessor): query: str, memories_data: Dict[str, Any], ) -> None: - """Enhance the LLM context with relevant memories. + """Enhance LLM context with retrieved memories. + + Uses XML tags <user_memories>...</user_memories> to wrap memories, + allowing replacement on each turn instead of accumulation. Args: context: The LLM context to enhance. query: The query used for retrieval. - memories_data: Raw memory data from Supermemory API. + memories_data: Memory data from Supermemory API. """ - # Skip if same query (avoid duplicate processing) if self._last_query == query: return self._last_query = query - # Extract and deduplicate memories profile = memories_data["profile"] - deduplicated = deduplicate_memories( static=profile["static"], dynamic=profile["dynamic"], search_results=memories_data["search_results"], ) - # Check if we have any memories total_memories = ( len(deduplicated["static"]) + len(deduplicated["dynamic"]) @@ -269,10 +216,8 @@ class SupermemoryPipecatService(FrameProcessor): ) if total_memories == 0: - logger.debug("No memories found to inject") return - # Format memories based on mode include_profile = self.params.mode in ("profile", "full") include_search = self.params.mode in ("query", "full") @@ -287,24 +232,55 @@ class SupermemoryPipecatService(FrameProcessor): if not memory_text: return - # Inject memories into context as user message - context.add_message({"role": "user", "content": memory_text}) + tagged_memory = f"{MEMORY_TAG_START}\n{memory_text}\n{MEMORY_TAG_END}" - logger.debug(f"Enhanced context with {total_memories} memories") + inject_to_system = self.params.inject_mode == "system" or ( + self.params.inject_mode == "auto" and self._audio_frames_detected + ) - async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: - """Process incoming frames, intercept context frames for memory integration. + messages = context.get_messages() + + if inject_to_system: + system_idx = None + for i, msg in enumerate(messages): + if msg.get("role") == "system": + system_idx = i + break + + if system_idx is not None: + existing_content = messages[system_idx].get("content", "") + if MEMORY_TAG_PATTERN.search(existing_content): + messages[system_idx]["content"] = MEMORY_TAG_PATTERN.sub( + tagged_memory, existing_content + ) + else: + messages[system_idx]["content"] = f"{existing_content}\n\n{tagged_memory}" + else: + messages.insert(0, {"role": "system", "content": tagged_memory}) + else: + # Remove previous memory message if exists + for i in range(len(messages) - 1, -1, -1): + msg = messages[i] + if msg.get("role") == "user" and MEMORY_TAG_START in msg.get("content", ""): + messages.pop(i) + break - Args: - frame: The incoming frame to process. - direction: The direction of frame flow in the pipeline. - """ + context.add_message({"role": "user", "content": tagged_memory}) + + async def process_frame(self, frame: Frame, direction: FrameDirection) -> None: + """Process frames, intercept context frames for memory integration.""" await super().process_frame(frame, direction) + # Auto-detect speech-to-speech mode via audio frames + if isinstance(frame, InputAudioRawFrame): + if not self._audio_frames_detected: + self._audio_frames_detected = True + await self.push_frame(frame, direction) + return + context = None messages = None - # Handle different frame types if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)): context = frame.context elif isinstance(frame, LLMMessagesFrame): @@ -313,29 +289,21 @@ class SupermemoryPipecatService(FrameProcessor): if context: try: - # Get messages from context context_messages = context.get_messages() latest_user_message = get_last_user_message(context_messages) if latest_user_message: - # Retrieve memories from Supermemory try: - memories_data = await self._retrieve_memories( - latest_user_message - ) + memories_data = await self._retrieve_memories(latest_user_message) self._enhance_context_with_memories( context, latest_user_message, memories_data ) except MemoryRetrievalError as e: - logger.warning( - f"Memory retrieval failed, continuing without memories: {e}" - ) + logger.warning(f"Memory retrieval failed: {e}") - # Store unsent messages (user and assistant only, skip system) + # Store unsent messages (user and assistant only) storable_messages = [ - msg - for msg in context_messages - if msg["role"] in ("user", "assistant") + msg for msg in context_messages if msg["role"] in ("user", "assistant") ] unsent_messages = storable_messages[self._messages_sent_count :] @@ -343,31 +311,19 @@ class SupermemoryPipecatService(FrameProcessor): asyncio.create_task(self._store_messages(unsent_messages)) self._messages_sent_count = len(storable_messages) - # Pass the frame downstream if messages is not None: - # For LLMMessagesFrame, create new frame with enhanced messages await self.push_frame(LLMMessagesFrame(context.get_messages())) else: - # For context frames, pass the enhanced frame await self.push_frame(frame) except Exception as e: - logger.error(f"Error processing frame with Supermemory: {e}") - # Still pass the original frame through on error + logger.error(f"Error processing frame: {e}") await self.push_frame(frame) else: - # Non-context frames pass through unchanged await self.push_frame(frame, direction) - def get_messages_sent_count(self) -> int: - """Get the count of messages sent to memory. - - Returns: - Number of messages already sent to Supermemory. - """ - return self._messages_sent_count - def reset_memory_tracking(self) -> None: - """Reset memory tracking for a new conversation.""" + """Reset memory tracking state for a new conversation.""" self._messages_sent_count = 0 self._last_query = None + self._audio_frames_detected = False |