aboutsummaryrefslogtreecommitdiff
path: root/packages/pipecat-sdk-python/src/supermemory_pipecat/service.py
blob: 2aef866bf41c64d6f8bfcd013e1a888ee4f83f35 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
"""Supermemory Pipecat service integration.

This module provides a memory service that integrates with Supermemory to store
and retrieve conversational memories, enhancing LLM context with relevant
historical information.
"""

import asyncio
import json
import os
import re
from typing import Any, Dict, List, Literal, Optional

from loguru import logger
from pydantic import BaseModel, Field

from pipecat.frames.frames import Frame, InputAudioRawFrame, LLMContextFrame, LLMMessagesFrame
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pydantic import BaseModel, Field

from .exceptions import ConfigurationError, MemoryRetrievalError
from .utils import deduplicate_memories, format_memories_to_text, get_last_user_message

try:
    import supermemory
except ImportError:
    supermemory = None  # type: ignore

# Delimiters wrapped around every injected memory block. Because each block is
# bracketed by the same pair of tags, later turns can locate the previous block
# and replace it rather than stacking a new one on top.
MEMORY_TAG_START = "<user_memories>"
MEMORY_TAG_END = "</user_memories>"
# Non-greedy and DOTALL so one match spans exactly one (possibly multi-line) block.
MEMORY_TAG_PATTERN = re.compile(
    re.escape(MEMORY_TAG_START) + ".*?" + re.escape(MEMORY_TAG_END),
    flags=re.DOTALL,
)


class SupermemoryPipecatService(FrameProcessor):
    """Memory service that integrates Supermemory with Pipecat pipelines.

    This service intercepts LLM context/message frames in the pipeline,
    retrieves memories from Supermemory for the latest user message, injects
    them into the context wrapped in ``<user_memories>`` tags, schedules
    background storage of new user/assistant messages, and then passes the
    frame on in its original direction.

    Example:
        ```python
        from supermemory_pipecat import SupermemoryPipecatService

        memory = SupermemoryPipecatService(
            api_key=os.getenv("SUPERMEMORY_API_KEY"),
            user_id="user-123",
        )
        ```
    """

    class InputParams(BaseModel):
        """Configuration parameters for memory retrieval and injection.

        Attributes:
            search_limit: Maximum number of memories to retrieve per query.
            search_threshold: Minimum similarity threshold (0.0-1.0).
            system_prompt: Prefix text for memory context.
            mode: Memory retrieval mode - "profile" (static/dynamic profile
                only), "query" (search results only), or "full" (both).
            inject_mode: How to inject memories - "system" (into the system
                prompt), "user" (as a separate user message), or "auto"
                (system once input audio frames have been seen, else user).
        """

        search_limit: int = Field(default=10, ge=1)
        search_threshold: float = Field(default=0.1, ge=0.0, le=1.0)
        system_prompt: str = Field(default="Based on previous conversations, I recall:\n\n")
        mode: Literal["profile", "query", "full"] = Field(default="full")
        inject_mode: Literal["auto", "system", "user"] = Field(default="auto")

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        user_id: str,
        session_id: Optional[str] = None,
        params: Optional[InputParams] = None,
        base_url: Optional[str] = None,
    ):
        """Initialize the Supermemory Pipecat service.

        Args:
            api_key: Supermemory API key. Falls back to SUPERMEMORY_API_KEY env var.
            user_id: The user ID - used as container_tag for memory scoping.
            session_id: Session/conversation ID for grouping memories.
            params: Configuration parameters for memory retrieval.
            base_url: Optional custom base URL for Supermemory API.

        Raises:
            ConfigurationError: If API key is missing or user_id not provided.
        """
        super().__init__()

        self.api_key = api_key or os.getenv("SUPERMEMORY_API_KEY")
        if not self.api_key:
            raise ConfigurationError(
                "API key is required. Provide api_key parameter or set SUPERMEMORY_API_KEY environment variable."
            )

        if not user_id:
            raise ConfigurationError("user_id is required")

        self.user_id = user_id
        self.container_tag = user_id
        self.session_id = session_id
        self.params = params or SupermemoryPipecatService.InputParams()

        # Client stays None when the optional `supermemory` package is missing
        # or fails to construct; retrieval then raises and storage is a no-op.
        self._supermemory_client = None
        if supermemory is not None:
            try:
                self._supermemory_client = supermemory.AsyncSupermemory(
                    api_key=self.api_key,
                    base_url=base_url,
                )
            except Exception as e:
                logger.warning(f"Failed to initialize Supermemory client: {e}")

        # Number of storable (user/assistant) messages already handed to storage.
        self._messages_sent_count: int = 0
        # Query used for the most recent memory injection (avoids re-injecting).
        self._last_query: Optional[str] = None
        # Set once an InputAudioRawFrame is seen; drives inject_mode="auto".
        self._audio_frames_detected: bool = False
        # Strong references to in-flight storage tasks. The event loop keeps
        # only weak references, so unreferenced tasks may be collected early.
        self._background_tasks: set = set()

    @staticmethod
    def _is_memory_message(message: Any) -> bool:
        """Return True if *message* is a user message carrying an injected memory block."""
        if not isinstance(message, dict) or message.get("role") != "user":
            return False
        content = message.get("content")
        return isinstance(content, str) and MEMORY_TAG_START in content

    async def _retrieve_memories(self, query: str) -> Dict[str, Any]:
        """Retrieve relevant memories from Supermemory.

        Args:
            query: The search query for memory retrieval. It is ignored when
                ``params.mode == "profile"`` or when empty.

        Returns:
            Dictionary with ``"profile"`` (a dict with ``"static"`` and
            ``"dynamic"`` entries) and ``"search_results"`` (a list, possibly
            empty).

        Raises:
            MemoryRetrievalError: If the client is unavailable or retrieval fails.
        """
        if self._supermemory_client is None:
            raise MemoryRetrievalError(
                "Supermemory client not initialized. Install with: pip install supermemory"
            )

        try:
            kwargs: Dict[str, Any] = {"container_tag": self.container_tag}

            if self.params.mode != "profile" and query:
                kwargs["q"] = query
                kwargs["threshold"] = self.params.search_threshold
                kwargs["extra_body"] = {"limit": self.params.search_limit}

            response = await self._supermemory_client.profile(**kwargs)

            search_results = []
            if response.search_results and response.search_results.results:
                search_results = response.search_results.results

            return {
                "profile": {
                    "static": response.profile.static,
                    "dynamic": response.profile.dynamic,
                },
                "search_results": search_results,
            }

        except Exception as e:
            logger.error(f"Error retrieving memories: {e}")
            raise MemoryRetrievalError("Failed to retrieve memories", e) from e

    async def _store_messages(self, messages: List[Dict[str, Any]]) -> None:
        """Store messages in Supermemory as one JSON-encoded document.

        Intended to run as a background task: errors are logged, never raised.
        When ``session_id`` is set it is sent as the document's ``custom_id``.
        """
        if self._supermemory_client is None or not messages:
            return

        try:
            add_params: Dict[str, Any] = {
                "content": json.dumps(messages),
                "container_tags": [self.container_tag],
                "metadata": {"platform": "pipecat"},
            }
            if self.session_id:
                add_params["custom_id"] = self.session_id

            await self._supermemory_client.memories.add(**add_params)

        except Exception as e:
            logger.error(f"Error storing messages: {e}")

    def _enhance_context_with_memories(
        self,
        context: LLMContext,
        query: str,
        memories_data: Dict[str, Any],
    ) -> None:
        """Enhance LLM context with retrieved memories.

        Uses XML tags <user_memories>...</user_memories> to wrap memories,
        allowing replacement on each turn instead of accumulation. Memory
        blocks previously injected as user messages are removed before the
        new block is injected, so at most one current block remains.

        NOTE(review): the in-place edits below assume ``context.get_messages()``
        returns the context's live message list - confirm against LLMContext.

        Args:
            context: The LLM context to enhance.
            query: The query used for retrieval.
            memories_data: Memory data from Supermemory API.
        """
        if self._last_query == query:
            return

        self._last_query = query

        profile = memories_data["profile"]
        deduplicated = deduplicate_memories(
            static=profile["static"],
            dynamic=profile["dynamic"],
            search_results=memories_data["search_results"],
        )

        total_memories = sum(
            len(deduplicated[key]) for key in ("static", "dynamic", "search_results")
        )
        if total_memories == 0:
            return

        include_profile = self.params.mode in ("profile", "full")
        include_search = self.params.mode in ("query", "full")

        memory_text = format_memories_to_text(
            deduplicated,
            system_prompt=self.params.system_prompt,
            include_static=include_profile,
            include_dynamic=include_profile,
            include_search=include_search,
        )

        if not memory_text:
            return

        tagged_memory = f"{MEMORY_TAG_START}\n{memory_text}\n{MEMORY_TAG_END}"

        inject_to_system = self.params.inject_mode == "system" or (
            self.params.inject_mode == "auto" and self._audio_frames_detected
        )

        messages = context.get_messages()

        # Drop stale user-message memory blocks in both modes: in "user" mode
        # they are replaced below; in "system" mode they would otherwise linger
        # if injection switched to the system prompt mid-conversation.
        messages[:] = [msg for msg in messages if not self._is_memory_message(msg)]

        if not inject_to_system:
            context.add_message({"role": "user", "content": tagged_memory})
            return

        system_msg = next(
            (msg for msg in messages if isinstance(msg, dict) and msg.get("role") == "system"),
            None,
        )
        if system_msg is None:
            messages.insert(0, {"role": "system", "content": tagged_memory})
            return

        existing_content = system_msg.get("content") or ""
        if not isinstance(existing_content, str):
            logger.warning("System message content is not plain text; skipping memory injection")
            return

        if MEMORY_TAG_PATTERN.search(existing_content):
            # A callable replacement inserts the text literally; a string
            # replacement would interpret backslashes in the memory text.
            system_msg["content"] = MEMORY_TAG_PATTERN.sub(
                lambda _match: tagged_memory, existing_content
            )
        elif existing_content:
            system_msg["content"] = f"{existing_content}\n\n{tagged_memory}"
        else:
            system_msg["content"] = tagged_memory

    async def _process_context(self, context: LLMContext) -> None:
        """Inject memories for the latest user turn and schedule storage of new messages."""
        # Exclude memory blocks we injected ourselves so they are neither used
        # as the search query nor stored back into Supermemory as conversation.
        conversation = [
            msg for msg in context.get_messages() if not self._is_memory_message(msg)
        ]

        latest_user_message = get_last_user_message(conversation)
        # Skip the network round-trip when the query is unchanged: the
        # enhancement step would discard the result anyway.
        if latest_user_message and latest_user_message != self._last_query:
            try:
                memories_data = await self._retrieve_memories(latest_user_message)
                self._enhance_context_with_memories(
                    context, latest_user_message, memories_data
                )
            except MemoryRetrievalError as e:
                logger.warning(f"Memory retrieval failed: {e}")

        self._schedule_storage(conversation)

    def _schedule_storage(self, conversation: List[Any]) -> None:
        """Store user/assistant messages not yet sent, in a background task."""
        storable_messages = [
            msg
            for msg in conversation
            if isinstance(msg, dict) and msg.get("role") in ("user", "assistant")
        ]
        unsent_messages = storable_messages[self._messages_sent_count :]
        if not unsent_messages:
            return

        task = asyncio.create_task(self._store_messages(unsent_messages))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        self._messages_sent_count = len(storable_messages)

    async def process_frame(self, frame: Frame, direction: FrameDirection) -> None:
        """Process frames, intercept context frames for memory integration.

        Input audio frames only flag speech-to-speech mode. LLM context and
        message frames are enhanced with memories and forwarded in the same
        direction; every other frame passes through unchanged.
        """
        await super().process_frame(frame, direction)

        # Auto-detect speech-to-speech mode via audio frames
        if isinstance(frame, InputAudioRawFrame):
            self._audio_frames_detected = True
            await self.push_frame(frame, direction)
            return

        messages = None
        if isinstance(frame, (LLMContextFrame, OpenAILLMContextFrame)):
            context = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            messages = frame.messages
            context = LLMContext(messages)
        else:
            await self.push_frame(frame, direction)
            return

        try:
            await self._process_context(context)
        except Exception as e:
            # Memory handling is best-effort: never block the pipeline on it.
            logger.error(f"Error processing frame: {e}")
            await self.push_frame(frame, direction)
            return

        if messages is not None:
            await self.push_frame(LLMMessagesFrame(context.get_messages()), direction)
        else:
            await self.push_frame(frame, direction)

    def reset_memory_tracking(self) -> None:
        """Reset memory tracking state for a new conversation."""
        self._messages_sent_count = 0
        self._last_query = None
        self._audio_frames_detected = False