aboutsummaryrefslogtreecommitdiff
path: root/packages/openai-sdk-python/src/supermemory_openai/middleware.py
blob: 4f6dc8ec09e055bad236a13a2e40d14cc238a903 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
"""Supermemory middleware for OpenAI clients."""

import asyncio
import os
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union, cast

import supermemory
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import (
    ChatCompletionMessageParam,
    ChatCompletionSystemMessageParam,
)

from .exceptions import (
    SupermemoryAPIError,
    SupermemoryConfigurationError,
    SupermemoryMemoryOperationError,
    SupermemoryNetworkError,
)
from .utils import (
    Logger,
    convert_profile_to_markdown,
    create_logger,
    deduplicate_memories,
    get_conversation_content,
    get_last_user_message,
)


@dataclass
class OpenAIMiddlewareOptions:
    """Configuration options for OpenAI middleware.

    Attributes:
        conversation_id: Optional conversation identifier. When set, the
            wrapper in this module stores the whole conversation (instead of
            only the last user message) under the custom id
            ``conversation:<conversation_id>``.
        verbose: Forwarded to ``create_logger`` when the wrapper builds its
            logger.
        mode: Which memories are injected into the system prompt:
            ``"profile"`` injects only profile memories, ``"query"`` only
            search results for the last user message, ``"full"`` both.
        add_memory: ``"always"`` saves the user's message as a memory on
            every chat completion call; ``"never"`` disables saving.
    """

    conversation_id: Optional[str] = None
    verbose: bool = False
    mode: Literal["profile", "query", "full"] = "profile"
    add_memory: Literal["always", "never"] = "never"


class SupermemoryProfileSearch:
    """Parsed response of the Supermemory profile search endpoint.

    Attributes:
        profile: Value of the response's ``"profile"`` key, or ``{}`` when
            the key is missing or null.
        search_results: Value of the response's ``"searchResults"`` key, or
            ``{}`` when the key is missing or null.
    """

    def __init__(self, data: dict[str, Any]):
        # ``dict.get(key, {})`` only covers a *missing* key; ``or {}`` also
        # normalises an explicit JSON ``null`` so both attributes are always
        # dictionaries, as their annotations promise.
        self.profile: dict[str, Any] = data.get("profile") or {}
        self.search_results: dict[str, Any] = data.get("searchResults") or {}

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(profile={self.profile!r}, "
            f"search_results={self.search_results!r})"
        )


async def supermemory_profile_search(
    container_tag: str,
    query_text: str,
    api_key: str,
) -> SupermemoryProfileSearch:
    """Search for memories using the SuperMemory profile API.

    Sends a POST request to the ``/v4/profile`` endpoint. ``aiohttp`` is used
    when it is installed; otherwise the request is made with ``requests`` on a
    worker thread so the running event loop is not blocked.

    Args:
        container_tag: Sent as ``containerTag`` in the request body.
        query_text: Sent as ``q`` in the request body; omitted when empty.
        api_key: Used as the bearer token of the ``Authorization`` header.

    Returns:
        The decoded JSON response wrapped in ``SupermemoryProfileSearch``.

    Raises:
        SupermemoryAPIError: If the API answers with a non-success status.
    """
    url = "https://api.supermemory.ai/v4/profile"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload: dict[str, str] = {"containerTag": container_tag}
    if query_text:
        payload["q"] = query_text

    # Keep the try body down to the import itself so that an ImportError
    # raised while the request is in flight is never mistaken for
    # "aiohttp is not installed".
    try:
        import aiohttp
    except ImportError:
        aiohttp = None

    if aiohttp is not None:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=payload) as response:
                if not response.ok:
                    error_text = await response.text()
                    raise SupermemoryAPIError(
                        "Supermemory profile search failed",
                        status_code=response.status,
                        response_text=error_text,
                    )

                return SupermemoryProfileSearch(await response.json())

    # Fallback to requests if aiohttp is not available. requests is blocking,
    # so the call is pushed to a worker thread instead of stalling the loop.
    import requests

    fallback_response = await asyncio.to_thread(
        requests.post, url, headers=headers, json=payload
    )

    if not fallback_response.ok:
        raise SupermemoryAPIError(
            "Supermemory profile search failed",
            status_code=fallback_response.status_code,
            response_text=fallback_response.text,
        )

    return SupermemoryProfileSearch(fallback_response.json())


def _append_memories_to_content(content: Any, memories: str) -> Any:
    """Return system-message ``content`` with ``memories`` appended.

    String content is joined with the memories using the " \\n " separator.
    Content given as a list/tuple of content parts gets an extra text part
    instead of being stringified into the prompt.
    """
    if content is None:
        content = ""
    if isinstance(content, (list, tuple)):
        return [*content, {"type": "text", "text": memories}]
    return f"{content} \n {memories}"


async def add_system_prompt(
    messages: list[ChatCompletionMessageParam],
    container_tag: str,
    logger: Logger,
    mode: Literal["profile", "query", "full"],
    api_key: str,
) -> list[ChatCompletionMessageParam]:
    """Add memory-enhanced system prompts to chat completion messages.

    Fetches memories for ``container_tag`` from the profile API, deduplicates
    them and renders them as text. The text is appended to every existing
    system message, or placed in a new leading system message when the
    conversation has none. The input list is not mutated.

    Args:
        messages: Chat completion messages of the pending request.
        container_tag: Container tag whose memories are looked up.
        logger: Logger receiving progress and debug records.
        mode: ``"profile"`` injects only profile memories, ``"query"`` only
            search results for the last user message, ``"full"`` both.
        api_key: Supermemory API key used for the lookup.

    Returns:
        A new message list carrying the memories in its system prompt.
    """
    # In "profile" mode no query is sent. ``or ""`` guards against a missing
    # user message so the slicing/len() in the log record below cannot fail.
    query_text = (get_last_user_message(messages) or "") if mode != "profile" else ""

    memories_response = await supermemory_profile_search(
        container_tag, query_text, api_key
    )

    profile = memories_response.profile or {}
    search_results_data = memories_response.search_results or {}
    # ``or []`` tolerates keys that are present but null in the response.
    static_memories = profile.get("static") or []
    dynamic_memories = profile.get("dynamic") or []
    search_memories = search_results_data.get("results") or []
    memory_count_static = len(static_memories)
    memory_count_dynamic = len(dynamic_memories)
    memory_count_search = len(search_memories)

    logger.info(
        "Memory search completed",
        {
            "container_tag": container_tag,
            "memory_count_static": memory_count_static,
            "memory_count_dynamic": memory_count_dynamic,
            "query_text": query_text[:100] + ("..." if len(query_text) > 100 else ""),
            "mode": mode,
        },
    )

    deduplicated = deduplicate_memories(
        static=static_memories,
        dynamic=dynamic_memories,
        search_results=search_memories,
    )

    logger.debug(
        "Memory deduplication completed",
        {
            "static": {
                "original": memory_count_static,
                "deduplicated": len(deduplicated.static),
            },
            "dynamic": {
                "original": memory_count_dynamic,
                "deduplicated": len(deduplicated.dynamic),
            },
            "search_results": {
                "original": memory_count_search,
                "deduplicated": len(deduplicated.search_results),
            },
        },
    )

    # Profile memories are rendered in every mode except "query".
    profile_data = ""
    if mode != "query":
        profile_data = convert_profile_to_markdown(
            {
                "profile": {
                    "static": deduplicated.static,
                    "dynamic": deduplicated.dynamic,
                },
                "searchResults": {"results": []},
            }
        )

    # Search results are rendered in every mode except "profile".
    search_results_memories = ""
    if mode != "profile" and deduplicated.search_results:
        search_results_memories = (
            "Search results for user's recent message: \n"
            + "\n".join(f"- {memory}" for memory in deduplicated.search_results)
        )

    memories = f"{profile_data}\n{search_results_memories}".strip()

    if memories:
        logger.debug(
            "Memory content preview",
            {
                "content": memories,
                "full_length": len(memories),
            },
        )

    if any(msg.get("role") == "system" for msg in messages):
        logger.debug("Added memories to existing system prompt")
        return [
            {
                **msg,
                "content": _append_memories_to_content(
                    msg.get("content", ""), memories
                ),
            }
            if msg.get("role") == "system"
            else msg
            for msg in messages
        ]

    logger.debug("System prompt does not exist, created system prompt with memories")
    system_message: ChatCompletionSystemMessageParam = {
        "role": "system",
        "content": memories,
    }
    return [system_message] + messages


async def add_memory_tool(
    client: supermemory.Supermemory,
    container_tag: str,
    content: str,
    custom_id: Optional[str],
    logger: Logger,
) -> None:
    """Add a new memory to the SuperMemory system.

    Works with both synchronous and asynchronous Supermemory clients:
    ``client.add`` is called exactly once and its result is awaited only when
    it is awaitable.

    Args:
        client: Supermemory client exposing an ``add(**params)`` method.
        container_tag: Container tag the memory is stored under.
        content: Text of the memory.
        custom_id: Optional custom identifier; omitted from the request when
            ``None``.
        logger: Logger receiving success and failure records.

    Raises:
        SupermemoryNetworkError: If saving fails with an ``OSError`` /
            ``ConnectionError``.
        SupermemoryMemoryOperationError: If saving fails for any other reason.
    """
    import inspect

    add_params: dict[str, Any] = {
        "content": content,
        "container_tags": [container_tag],
    }
    if custom_id is not None:
        add_params["custom_id"] = custom_id

    try:
        # Call once, then await only if needed. Awaiting first and retrying on
        # TypeError would run a synchronous ``add`` twice (storing the memory
        # twice) and would also hide genuine TypeErrors raised by the client.
        response = client.add(**add_params)
        if inspect.isawaitable(response):
            response = await response

        logger.info(
            "Memory saved successfully",
            {
                "container_tag": container_tag,
                "custom_id": custom_id,
                "content_length": len(content),
                "memory_id": response.id,
            },
        )
    except (OSError, ConnectionError) as network_error:
        logger.error(
            "Network error while saving memory",
            {"error": str(network_error)},
        )
        raise SupermemoryNetworkError(
            "Failed to save memory due to network error", network_error
        ) from network_error
    except Exception as error:
        logger.error(
            "Error saving memory",
            {"error": str(error)},
        )
        raise SupermemoryMemoryOperationError(
            "Failed to save memory", error
        ) from error


class SupermemoryOpenAIWrapper:
    """Wrapper for OpenAI client with Supermemory middleware.

    On construction the wrapped client's ``chat.completions.create`` is
    replaced (in place) by a function that optionally stores the user's
    message as a memory and injects stored memories into the system prompt
    before delegating to the original ``create``. Every other attribute is
    delegated to the wrapped client.
    """

    def __init__(
        self,
        openai_client: Union[OpenAI, AsyncOpenAI],
        container_tag: str,
        options: Optional[OpenAIMiddlewareOptions] = None,
    ):
        """Wrap ``openai_client``.

        Args:
            openai_client: Sync or async OpenAI client to wrap.
            container_tag: Container tag used for memory search and storage.
            options: Middleware options; defaults are used when ``None``.

        Raises:
            SupermemoryConfigurationError: If the ``supermemory`` package
                lacks ``Supermemory``, the ``SUPERMEMORY_API_KEY`` environment
                variable is unset, or the Supermemory client cannot be built.
        """
        self._client: Union[OpenAI, AsyncOpenAI] = openai_client
        self._container_tag: str = container_tag
        self._options: OpenAIMiddlewareOptions = options or OpenAIMiddlewareOptions()
        self._logger: Logger = create_logger(self._options.verbose)

        # Track background tasks to ensure they complete
        self._background_tasks: set[asyncio.Task] = set()

        if not hasattr(supermemory, "Supermemory"):
            raise SupermemoryConfigurationError(
                "supermemory package is required but not found",
                ImportError("supermemory package not installed"),
            )

        api_key = self._get_api_key()
        try:
            self._supermemory_client: supermemory.Supermemory = supermemory.Supermemory(
                api_key=api_key
            )
        except Exception as e:
            raise SupermemoryConfigurationError(
                f"Failed to initialize Supermemory client: {e}", e
            ) from e

        # Wrap the chat completions create method
        self._wrap_chat_completions()

    def _get_api_key(self) -> str:
        """Get Supermemory API key from environment.

        Raises:
            SupermemoryConfigurationError: If ``SUPERMEMORY_API_KEY`` is unset
                or empty.
        """
        api_key = os.getenv("SUPERMEMORY_API_KEY")
        if not api_key:
            raise SupermemoryConfigurationError(
                "SUPERMEMORY_API_KEY environment variable is required but not set"
            )
        return api_key

    def _wrap_chat_completions(self) -> None:
        """Wrap the chat completions create method with memory injection."""
        original_create = self._client.chat.completions.create

        if asyncio.iscoroutinefunction(original_create):

            async def create_with_memory(
                **kwargs: Any,
            ) -> Any:
                return await self._create_with_memory_async(original_create, **kwargs)
        else:

            def create_with_memory(
                **kwargs: Any,
            ) -> Any:
                return self._create_with_memory_sync(original_create, **kwargs)

        # Replace the create method with our wrapper
        setattr(self._client.chat.completions, "create", create_with_memory)

    def _memory_to_store(
        self, messages: Any
    ) -> Optional[tuple[str, Optional[str]]]:
        """Return ``(content, custom_id)`` to save, or ``None`` if nothing to save."""
        if self._options.add_memory != "always":
            return None

        user_message = get_last_user_message(messages)
        if not (user_message and user_message.strip()):
            return None

        conversation_id = self._options.conversation_id
        if conversation_id:
            # With a conversation id the whole conversation is stored under a
            # stable custom id; otherwise only the last user message is.
            return get_conversation_content(messages), f"conversation:{conversation_id}"
        return user_message, None

    def _should_skip_search(self, messages: Any) -> bool:
        """Return True when the mode needs a user message but none exists."""
        if self._options.mode == "profile":
            return False
        if get_last_user_message(messages):
            return False
        self._logger.debug("No user message found, skipping memory search")
        return True

    def _log_search_start(self) -> None:
        """Log the start of a memory search."""
        self._logger.info(
            "Starting memory search",
            {
                "container_tag": self._container_tag,
                "conversation_id": self._options.conversation_id,
                "mode": self._options.mode,
            },
        )

    def _log_background_task_outcome(self, task_obj: Any) -> None:
        """Done-callback: log a background task failure without raising."""
        try:
            exception = task_obj.exception()
        except asyncio.CancelledError:
            self._logger.debug("Memory storage task was cancelled")
            return

        if exception is None:
            return

        details = {"error": str(exception), "type": type(exception).__name__}
        if isinstance(exception, (SupermemoryNetworkError, SupermemoryAPIError)):
            self._logger.warn("Background memory storage failed", details)
        else:
            self._logger.error(
                "Unexpected error in background memory storage", details
            )

    @staticmethod
    def _run_blocking(coro_factory: Any) -> Any:
        """Run the coroutine built by ``coro_factory`` and return its result.

        Uses ``asyncio.run`` directly when the calling thread has no running
        event loop, and a one-off worker thread otherwise. The coroutine is
        only created where it will actually run, so no un-awaited coroutine
        is left behind.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro_factory())

        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(lambda: asyncio.run(coro_factory())).result()

    async def _create_with_memory_async(
        self,
        original_create: Any,
        **kwargs: Any,
    ) -> Any:
        """Async version of create with memory injection."""
        messages = kwargs.get("messages", [])

        pending = self._memory_to_store(messages)
        if pending is not None:
            content, custom_id = pending

            # Create background task for memory storage
            task = asyncio.create_task(
                add_memory_tool(
                    self._supermemory_client,
                    self._container_tag,
                    content,
                    custom_id,
                    self._logger,
                )
            )

            # Track the task, drop it once finished, and log any exception
            # without failing the main request.
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
            task.add_done_callback(self._log_background_task_outcome)

        if self._should_skip_search(messages):
            return await original_create(**kwargs)

        self._log_search_start()

        kwargs["messages"] = await add_system_prompt(
            messages,
            self._container_tag,
            self._logger,
            self._options.mode,
            self._get_api_key(),
        )
        return await original_create(**kwargs)

    def _create_with_memory_sync(
        self,
        original_create: Any,
        **kwargs: Any,
    ) -> Any:
        """Sync version of create with memory injection."""
        # For sync clients, we implement a simplified version without background tasks
        messages = kwargs.get("messages", [])

        # Handle memory addition synchronously if needed. Saving is
        # best-effort: failures are logged and never fail the main request.
        pending = self._memory_to_store(messages)
        if pending is not None:
            content, custom_id = pending
            try:
                self._run_blocking(
                    lambda: add_memory_tool(
                        self._supermemory_client,
                        self._container_tag,
                        content,
                        custom_id,
                        self._logger,
                    )
                )
            except SupermemoryNetworkError as e:
                # Network errors are expected, log as warning
                self._logger.warn("Network error saving memory", {"error": str(e)})
            except (SupermemoryAPIError, SupermemoryMemoryOperationError) as e:
                # API/memory errors are concerning, log as error
                self._logger.error("Failed to save memory", {"error": str(e)})
            except Exception as e:
                # Unexpected errors should be investigated
                self._logger.error(
                    "Unexpected error saving memory",
                    {"error": str(e), "type": type(e).__name__},
                )

        # Handle memory search and injection
        if self._should_skip_search(messages):
            return original_create(**kwargs)

        self._log_search_start()

        # Resolve the key in the calling thread so configuration errors
        # surface here rather than inside the helper thread.
        api_key = self._get_api_key()
        kwargs["messages"] = self._run_blocking(
            lambda: add_system_prompt(
                messages,
                self._container_tag,
                self._logger,
                self._options.mode,
                api_key,
            )
        )
        return original_create(**kwargs)

    async def wait_for_background_tasks(self, timeout: Optional[float] = 10.0) -> None:
        """
        Wait for all background memory storage tasks to complete.

        Args:
            timeout: Maximum time to wait in seconds. None for no timeout.

        Raises:
            asyncio.TimeoutError: If tasks don't complete within timeout
        """
        if not self._background_tasks:
            return

        # Snapshot: done-callbacks remove finished tasks from the live set.
        pending = list(self._background_tasks)

        self._logger.debug(
            f"Waiting for {len(pending)} background tasks to complete"
        )

        try:
            if timeout is not None:
                await asyncio.wait_for(
                    asyncio.gather(*pending, return_exceptions=True),
                    timeout=timeout,
                )
            else:
                await asyncio.gather(*pending, return_exceptions=True)

            self._logger.debug("All background tasks completed")
        except asyncio.TimeoutError:
            self._logger.warn(
                f"Background tasks did not complete within {timeout}s timeout"
            )
            # Cancel remaining tasks
            for task in pending:
                if not task.done():
                    task.cancel()
            raise

    def cancel_background_tasks(self) -> None:
        """Cancel all pending background tasks."""
        cancelled_count = 0
        for task in list(self._background_tasks):
            if not task.done():
                task.cancel()
                cancelled_count += 1

        if cancelled_count > 0:
            self._logger.debug(f"Cancelled {cancelled_count} pending background tasks")

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit - wait for background tasks."""
        try:
            await self.wait_for_background_tasks(timeout=5.0)
        except asyncio.TimeoutError:
            self._logger.warn("Some background memory tasks did not complete on exit")

    def __enter__(self):
        """Sync context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Sync context manager exit - attempt to wait for background tasks."""
        if not self._background_tasks:
            return

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            # A loop is already running in this thread, so we cannot block on
            # the tasks here; cancel them instead.
            self._logger.warn(
                "Cannot wait for background tasks in sync context from async environment. "
                "Use async context manager or call wait_for_background_tasks() manually."
            )
            self.cancel_background_tasks()
            return

        try:
            # Try to wait for background tasks in sync context
            asyncio.run(self.wait_for_background_tasks(timeout=5.0))
        except asyncio.TimeoutError:
            self._logger.warn(
                "Some background memory tasks did not complete on exit"
            )
            self.cancel_background_tasks()

    def __getattr__(self, name: str) -> Any:
        """Delegate all other attributes to the wrapped client."""
        # ``__getattr__`` only runs after normal lookup failed. If ``_client``
        # itself is not set yet (instance created without ``__init__``, e.g.
        # by copy/pickle), ``self._client`` would re-enter this method forever,
        # so look it up without triggering ``__getattr__`` again.
        try:
            client = object.__getattribute__(self, "_client")
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(client, name)


def with_supermemory(
    openai_client: Union[OpenAI, AsyncOpenAI],
    container_tag: str,
    options: Optional[OpenAIMiddlewareOptions] = None,
) -> Union[OpenAI, AsyncOpenAI]:
    """
    Attach the SuperMemory middleware to an OpenAI client.

    The returned object is a ``SupermemoryOpenAIWrapper``. It replaces the
    client's ``chat.completions.create`` so that memories found for
    ``container_tag`` are appended to an existing system prompt, or placed in
    a newly created one, before the request is sent. All other attributes are
    delegated to the original client.

    Args:
        openai_client: The sync or async OpenAI client to wrap
        container_tag: The container tag/identifier for memory search (e.g., user ID, project ID)
        options: Optional configuration options for the middleware

    Returns:
        The wrapper, typed as the OpenAI client it stands in for

    Example:
        ```python
        import os

        from openai import OpenAI
        from supermemory_openai import OpenAIMiddlewareOptions, with_supermemory

        # Create OpenAI client with supermemory middleware
        openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        openai_with_supermemory = with_supermemory(
            openai,
            "user-123",
            OpenAIMiddlewareOptions(
                conversation_id="conversation-456",
                mode="full",
                add_memory="always"
            )
        )

        # Use normally - memories will be automatically injected.
        # (With AsyncOpenAI, ``await`` the call instead.)
        response = openai_with_supermemory.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "user", "content": "What's my favorite programming language?"}
            ]
        )
        ```

    Raises:
        SupermemoryConfigurationError: When the SUPERMEMORY_API_KEY environment
            variable is not set or the Supermemory client cannot be created
    """
    # The cast only informs type checkers; at runtime the wrapper is returned
    # and forwards unknown attributes to the wrapped client.
    return cast(
        Union[OpenAI, AsyncOpenAI],
        SupermemoryOpenAIWrapper(openai_client, container_tag, options),
    )