aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
Diffstat (limited to 'discord')
-rw-r--r--discord/guild.py16
-rw-r--r--discord/state.py55
-rw-r--r--discord/threads.py2
3 files changed, 67 insertions, 6 deletions
diff --git a/discord/guild.py b/discord/guild.py
index 4e7e38f6..e32c2fcf 100644
--- a/discord/guild.py
+++ b/discord/guild.py
@@ -26,7 +26,7 @@ from __future__ import annotations
import copy
from collections import namedtuple
-from typing import Dict, List, Literal, Optional, TYPE_CHECKING, Union, overload
+from typing import Dict, List, Set, Literal, Optional, TYPE_CHECKING, Union, overload
from . import utils, abc
from .role import Role
@@ -227,6 +227,20 @@ class Guild(Hashable):
def _remove_thread(self, thread):
self._threads.pop(thread.id, None)
+ def _clear_threads(self):
+ self._threads.clear()
+
+ def _remove_threads_by_channel(self, channel_id: int):
+ to_remove = [k for k, t in self._threads.items() if t.parent_id == channel_id]
+ for k in to_remove:
+ del self._threads[k]
+
+ def _filter_threads(self, channel_ids: Set[int]) -> Dict[int, Thread]:
+ to_remove: Dict[int, Thread] = {k: t for k, t in self._threads.items() if t.parent_id in channel_ids}
+ for k in to_remove:
+ del self._threads[k]
+ return to_remove
+
def __str__(self):
return self.name or ''
diff --git a/discord/state.py b/discord/state.py
index 85a94cee..d8c322ac 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -716,7 +716,7 @@ class ConnectionState:
thread = Thread(guild=guild, data=data)
guild._add_thread(thread)
- self.dispatch('thread_create', thread)
+ self.dispatch('thread_join', thread)
def parse_thread_update(self, data):
guild_id = int(data['guild_id'])
@@ -752,6 +752,16 @@ class ConnectionState:
log.debug('THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding', guild_id)
return
+ try:
+ channel_ids = set(data['channel_ids'])
+ except KeyError:
+ # If not provided, then the entire guild is being synced
+ # So all previous thread data should be overwritten
+ previous_threads = guild._threads.copy()
+ guild._clear_threads()
+ else:
+ previous_threads = guild._filter_threads(channel_ids)
+
threads = {
d['id']: guild._store_thread(d)
for d in data.get('threads', [])
@@ -766,7 +776,13 @@ class ConnectionState:
else:
thread._add_member(ThreadMember(thread, member))
- # TODO: dispatch?
+ for thread in threads.values():
+ old = previous_threads.pop(thread.id, None)
+ if old is None:
+ self.dispatch('thread_join', thread)
+
+ for thread in previous_threads.values():
+ self.dispatch('thread_remove', thread)
def parse_thread_member_update(self, data):
guild_id = int(data['guild_id'])
@@ -776,15 +792,44 @@ class ConnectionState:
return
thread_id = int(data['id'])
- thread = guild.get_thread(thread_id)
+ thread: Optional[Thread] = guild.get_thread(thread_id)
if thread is None:
log.debug('THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id)
return
member = ThreadMember(thread, data)
- thread._add_member(member)
+ thread.me = member
+
+ def parse_thread_members_update(self, data):
+ guild_id = int(data['guild_id'])
+ guild: Optional[Guild] = self._get_guild(guild_id)
+ if guild is None:
+ log.debug('THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id)
+ return
- # TODO: dispatch
+ thread_id = int(data['id'])
+ thread: Optional[Thread] = guild.get_thread(thread_id)
+ if thread is None:
+ log.debug('THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id)
+ return
+
+ added_members = [ThreadMember(thread, d) for d in data.get('added_members', [])]
+ removed_member_ids = data.get('removed_member_ids', [])
+ self_id = self.self_id
+ for member in added_members:
+ if member.id != self_id:
+ thread._add_member(member)
+ self.dispatch('thread_member_join', member)
+ else:
+ thread.me = member
+ self.dispatch('thread_join', thread)
+
+ for member_id in removed_member_ids:
+ if member_id != self_id:
+ member = thread._pop_member(member_id)
+ self.dispatch('thread_member_leave', member)
+ else:
+ self.dispatch('thread_remove', thread)
def parse_guild_member_add(self, data):
guild = self._get_guild(int(data['guild_id']))
diff --git a/discord/threads.py b/discord/threads.py
index 1a8a6af1..cf6d92aa 100644
--- a/discord/threads.py
+++ b/discord/threads.py
@@ -383,6 +383,8 @@ class Thread(Messageable, Hashable):
def _add_member(self, member: ThreadMember) -> None:
self._members[member.id] = member
+ def _pop_member(self, member_id: int) -> Optional[ThreadMember]:
+ return self._members.pop(member_id, None)
class ThreadMember(Hashable):
"""Represents a Discord thread member.