about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- src/umabot/bot.py | 12
-rw-r--r-- src/umabot/rules/spam_detector.py | 50
-rw-r--r-- tests/test_config.py | 4
-rw-r--r-- tests/test_spam_detector.py | 74
4 files changed, 113 insertions, 27 deletions
diff --git a/src/umabot/bot.py b/src/umabot/bot.py
index 2519962..89b876a 100644
--- a/src/umabot/bot.py
+++ b/src/umabot/bot.py
@@ -2,8 +2,8 @@
import time
import threading
+from collections import OrderedDict
import praw
-from typing import List
from http.server import HTTPServer, BaseHTTPRequestHandler
from socketserver import ThreadingMixIn
from loguru import logger
@@ -93,7 +93,7 @@ class UmaBot:
]
# Track processed submissions to avoid processing old posts
- self.processed_submissions = set()
+ self.processed_submissions = OrderedDict()
self.initialized = False
self.logger.info(f"Bot initialized for r/{config.subreddit_name}")
@@ -140,7 +140,7 @@ class UmaBot:
if not self.initialized:
self.logger.info("Initializing bot - marking existing posts as processed")
for submission in new_submissions:
- self.processed_submissions.add(submission.id)
+ self.processed_submissions[submission.id] = None
self.initialized = True
self.logger.info(f"Bot initialized with {len(self.processed_submissions)} existing posts marked as processed")
return
@@ -150,7 +150,7 @@ class UmaBot:
for submission in new_submissions:
if submission.id not in self.processed_submissions:
truly_new_submissions.append(submission)
- self.processed_submissions.add(submission.id)
+ self.processed_submissions[submission.id] = None
if not truly_new_submissions:
self.logger.debug("No truly new submissions found")
@@ -224,6 +224,6 @@ class UmaBot:
# Keep only the last 1000 processed submissions
if len(self.processed_submissions) > 1000:
# Convert to list, keep last 1000, convert back to set
- submissions_list = list(self.processed_submissions)
- self.processed_submissions = set(submissions_list[-1000:])
+ while len(self.processed_submissions) > 1000:
+ self.processed_submissions.popitem(last=False)
self.logger.debug(f"Cleaned up processed submissions, keeping {len(self.processed_submissions)} most recent")
diff --git a/src/umabot/rules/spam_detector.py b/src/umabot/rules/spam_detector.py
index 2de48f2..861616c 100644
--- a/src/umabot/rules/spam_detector.py
+++ b/src/umabot/rules/spam_detector.py
@@ -1,8 +1,7 @@
"""Spam detection rule for limiting posts per user per day."""
-import time
from datetime import datetime, timedelta, timezone
-from typing import Dict, List
+from typing import Dict, Set
import praw.models
from .base import Rule
@@ -13,7 +12,7 @@ class SpamDetector(Rule):
def __init__(self, config):
"""Initialize the spam detector."""
super().__init__(config)
- self.user_posts: Dict[str, List[tuple[float, str]]] = {} # (timestamp, post_id)
+ self.user_posts: Dict[str, Dict[str, float]] = {}
self.max_posts = config.max_posts_per_day
def should_remove(self, submission: praw.models.Submission) -> bool:
@@ -23,22 +22,23 @@ class SpamDetector(Rule):
username = submission.author.name
current_utc = datetime.now(timezone.utc)
+ submission_utc = self._get_submission_utc(submission)
- # Clean old posts from tracking (remove posts from previous days)
self._clean_old_posts(username, current_utc)
+
+ if submission_utc.date() != current_utc.date():
+ return False
- # Count current active posts in today's UTC day
if username not in self.user_posts:
- self.user_posts[username] = []
+ self.user_posts[username] = {}
- # Filter out removed posts and count active ones
active_posts = self._get_active_posts(username, current_utc)
- post_count = len(active_posts)
-
- # Add current post to tracking
- self.user_posts[username].append((current_utc.timestamp(), submission.id))
+ if submission.id in active_posts:
+ return False
- # Check if this post exceeds the limit
+ post_count = len(active_posts)
+ self.user_posts[username][submission.id] = submission_utc.timestamp()
+
if post_count >= self.max_posts:
self.logger.info(
f"User {username} has posted {post_count + 1} active times today (UTC) "
@@ -81,28 +81,36 @@ class SpamDetector(Rule):
today_timestamp = today_start.timestamp()
# Keep only posts from today
- self.user_posts[username] = [
- (post_time, post_id) for post_time, post_id in self.user_posts[username]
+ self.user_posts[username] = {
+ post_id: post_time
+ for post_id, post_time in self.user_posts[username].items()
if post_time >= today_timestamp
- ]
+ }
- def _get_active_posts(self, username: str, current_utc: datetime) -> List[tuple[float, str]]:
+ def _get_active_posts(self, username: str, current_utc: datetime) -> Set[str]:
"""Get active (non-removed) posts for a user."""
if username not in self.user_posts:
- return []
+ return set()
# Get start of current UTC day
today_start = current_utc.replace(hour=0, minute=0, second=0, microsecond=0)
today_timestamp = today_start.timestamp()
- active_posts = []
- for post_time, post_id in self.user_posts[username]:
+ active_posts = set()
+ for post_id, post_time in self.user_posts[username].items():
if post_time >= today_timestamp:
- # Check if the post is still active (not removed)
if self._is_post_active(post_id):
- active_posts.append((post_time, post_id))
+ active_posts.add(post_id)
return active_posts
+
+ def _get_submission_utc(self, submission: praw.models.Submission) -> datetime:
+ """Get the submission timestamp in UTC."""
+ created_utc = getattr(submission, "created_utc", None)
+ if created_utc is None:
+ return datetime.now(timezone.utc)
+
+ return datetime.fromtimestamp(created_utc, timezone.utc)
def _is_post_active(self, post_id: str) -> bool:
"""Check if a post is still active (not removed) by checking its status."""
diff --git a/tests/test_config.py b/tests/test_config.py
index a160a84..71879c1 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -14,8 +14,10 @@ def test_config_from_env():
"REDDIT_CLIENT_SECRET": "test_client_secret",
"REDDIT_USERNAME": "test_username",
"REDDIT_PASSWORD": "test_password",
+ "OPENAI_API_KEY": "test_openai_key",
"SUBREDDIT_NAME": "test_subreddit",
"ROLEPLAY_MESSAGE": "Test roleplay message",
+ "DRY_RUN": "false",
}
with patch.dict(os.environ, test_env):
@@ -40,6 +42,7 @@ def test_config_validation():
username="",
password="",
user_agent="test",
+ openai_api_key="",
subreddit_name="",
roleplay_message=""
)
@@ -56,6 +59,7 @@ def test_config_validation_success():
username="test",
password="test",
user_agent="test",
+ openai_api_key="test",
subreddit_name="test",
roleplay_message="test"
)
diff --git a/tests/test_spam_detector.py b/tests/test_spam_detector.py
new file mode 100644
index 0000000..4ac8591
--- /dev/null
+++ b/tests/test_spam_detector.py
@@ -0,0 +1,74 @@
+"""Tests for spam detection rule behavior."""
+
+import time
+from collections import OrderedDict
+from types import SimpleNamespace
+from typing import Optional
+from unittest.mock import MagicMock
+
+from umabot.bot import UmaBot
+from umabot.rules.spam_detector import SpamDetector
+
+
+def make_submission(post_id: str, username: str, created_utc: Optional[float] = None):
+ """Create a submission-like test object."""
+ return SimpleNamespace(
+ id=post_id,
+ author=SimpleNamespace(name=username),
+ created_utc=created_utc if created_utc is not None else time.time(),
+ )
+
+
+def make_config(max_posts_per_day: int = 3):
+ """Create a minimal config-like object for tests."""
+ return SimpleNamespace(
+ max_posts_per_day=max_posts_per_day,
+ subreddit_name="okbuddyumamusume",
+ )
+
+
+def test_spam_detector_removes_fourth_unique_post_same_day():
+ """Fourth unique post in the same UTC day should be removed."""
+ detector = SpamDetector(make_config())
+
+ assert detector.should_remove(make_submission("p1", "alice")) is False
+ assert detector.should_remove(make_submission("p2", "alice")) is False
+ assert detector.should_remove(make_submission("p3", "alice")) is False
+ assert detector.should_remove(make_submission("p4", "alice")) is True
+
+
+def test_spam_detector_does_not_double_count_same_submission():
+ """Seeing the same submission again should not increment the user's count."""
+ detector = SpamDetector(make_config())
+ submission = make_submission("p1", "alice")
+
+ assert detector.should_remove(submission) is False
+ assert detector.should_remove(submission) is False
+ assert detector.should_remove(make_submission("p2", "alice")) is False
+ assert detector.should_remove(make_submission("p3", "alice")) is False
+ assert detector.should_remove(make_submission("p4", "alice")) is True
+
+
+def test_spam_detector_ignores_old_submission_seen_today():
+ """Old submissions should not be treated as today's posts."""
+ detector = SpamDetector(make_config())
+ two_days_ago = time.time() - (2 * 24 * 60 * 60)
+
+ assert detector.should_remove(make_submission("old", "alice", two_days_ago)) is False
+ assert detector.should_remove(make_submission("p1", "alice")) is False
+ assert detector.should_remove(make_submission("p2", "alice")) is False
+ assert detector.should_remove(make_submission("p3", "alice")) is False
+
+
+def test_cleanup_processed_submissions_prunes_oldest_entries():
+ """Processed submission cleanup should keep the most recently seen IDs."""
+ bot = UmaBot.__new__(UmaBot)
+ bot.logger = MagicMock()
+ bot.processed_submissions = OrderedDict((str(i), None) for i in range(1005))
+
+ UmaBot._cleanup_processed_submissions(bot)
+
+ assert len(bot.processed_submissions) == 1000
+ assert "0" not in bot.processed_submissions
+ assert "4" not in bot.processed_submissions
+ assert "5" in bot.processed_submissions