aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2021-05-09 22:23:21 -0400
committerRapptz <[email protected]>2021-06-08 07:29:17 -0400
commitbd369c76ea5424f65e37d84bfb45df1c76a4e739 (patch)
tree435f8a36703721a4ef74ef0a4682dd5dd8655a3c
parentAdd ThreadMember.thread (diff)
downloaddiscord.py-bd369c76ea5424f65e37d84bfb45df1c76a4e739.tar.xz
discord.py-bd369c76ea5424f65e37d84bfb45df1c76a4e739.zip
Parse remaining thread events.
-rw-r--r--discord/guild.py16
-rw-r--r--discord/state.py55
-rw-r--r--discord/threads.py2
-rw-r--r--docs/api.rst50
4 files changed, 114 insertions, 9 deletions
diff --git a/discord/guild.py b/discord/guild.py
index 4e7e38f6..e32c2fcf 100644
--- a/discord/guild.py
+++ b/discord/guild.py
@@ -26,7 +26,7 @@ from __future__ import annotations
import copy
from collections import namedtuple
-from typing import Dict, List, Literal, Optional, TYPE_CHECKING, Union, overload
+from typing import Dict, List, Set, Literal, Optional, TYPE_CHECKING, Union, overload
from . import utils, abc
from .role import Role
@@ -227,6 +227,20 @@ class Guild(Hashable):
def _remove_thread(self, thread):
self._threads.pop(thread.id, None)
+ def _clear_threads(self):
+ self._threads.clear()
+
+ def _remove_threads_by_channel(self, channel_id: int):
+ to_remove = [k for k, t in self._threads.items() if t.parent_id == channel_id]
+ for k in to_remove:
+ del self._threads[k]
+
+ def _filter_threads(self, channel_ids: Set[int]) -> Dict[int, Thread]:
+ to_remove: Dict[int, Thread] = {k: t for k, t in self._threads.items() if t.parent_id in channel_ids}
+ for k in to_remove:
+ del self._threads[k]
+ return to_remove
+
def __str__(self):
return self.name or ''
diff --git a/discord/state.py b/discord/state.py
index 85a94cee..d8c322ac 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -716,7 +716,7 @@ class ConnectionState:
thread = Thread(guild=guild, data=data)
guild._add_thread(thread)
- self.dispatch('thread_create', thread)
+ self.dispatch('thread_join', thread)
def parse_thread_update(self, data):
guild_id = int(data['guild_id'])
@@ -752,6 +752,16 @@ class ConnectionState:
log.debug('THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding', guild_id)
return
+ try:
+ channel_ids = set(data['channel_ids'])
+ except KeyError:
+ # If not provided, then the entire guild is being synced
+ # So all previous thread data should be overwritten
+ previous_threads = guild._threads.copy()
+ guild._clear_threads()
+ else:
+ previous_threads = guild._filter_threads(channel_ids)
+
threads = {
d['id']: guild._store_thread(d)
for d in data.get('threads', [])
@@ -766,7 +776,13 @@ class ConnectionState:
else:
thread._add_member(ThreadMember(thread, member))
- # TODO: dispatch?
+ for thread in threads.values():
+ old = previous_threads.pop(thread.id, None)
+ if old is None:
+ self.dispatch('thread_join', thread)
+
+ for thread in previous_threads.values():
+ self.dispatch('thread_remove', thread)
def parse_thread_member_update(self, data):
guild_id = int(data['guild_id'])
@@ -776,15 +792,44 @@ class ConnectionState:
return
thread_id = int(data['id'])
- thread = guild.get_thread(thread_id)
+ thread: Optional[Thread] = guild.get_thread(thread_id)
if thread is None:
log.debug('THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id)
return
member = ThreadMember(thread, data)
- thread._add_member(member)
+ thread.me = member
+
+ def parse_thread_members_update(self, data):
+ guild_id = int(data['guild_id'])
+ guild: Optional[Guild] = self._get_guild(guild_id)
+ if guild is None:
+ log.debug('THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id)
+ return
- # TODO: dispatch
+ thread_id = int(data['id'])
+ thread: Optional[Thread] = guild.get_thread(thread_id)
+ if thread is None:
+ log.debug('THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id)
+ return
+
+ added_members = [ThreadMember(thread, d) for d in data.get('added_members', [])]
+ removed_member_ids = data.get('removed_member_ids', [])
+ self_id = self.self_id
+ for member in added_members:
+ if member.id != self_id:
+ thread._add_member(member)
+ self.dispatch('thread_member_join', member)
+ else:
+ thread.me = member
+ self.dispatch('thread_join', thread)
+
+ for member_id in removed_member_ids:
+ if member_id != self_id:
+ member = thread._pop_member(member_id)
+ self.dispatch('thread_member_leave', member)
+ else:
+ self.dispatch('thread_remove', thread)
def parse_guild_member_add(self, data):
guild = self._get_guild(int(data['guild_id']))
diff --git a/discord/threads.py b/discord/threads.py
index 1a8a6af1..cf6d92aa 100644
--- a/discord/threads.py
+++ b/discord/threads.py
@@ -383,6 +383,8 @@ class Thread(Messageable, Hashable):
def _add_member(self, member: ThreadMember) -> None:
self._members[member.id] = member
+ def _pop_member(self, member_id: int) -> Optional[ThreadMember]:
+ return self._members.pop(member_id, None)
class ThreadMember(Hashable):
"""Represents a Discord thread member.
diff --git a/docs/api.rst b/docs/api.rst
index 4282373e..395de477 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -658,10 +658,42 @@ to handle it, which defaults to print a traceback and ignoring the exception.
:param last_pin: The latest message that was pinned as an aware datetime in UTC. Could be ``None``.
:type last_pin: Optional[:class:`datetime.datetime`]
+.. function:: on_thread_join(thread)
+
+ Called whenever a thread is joined.
+
+ Note that you can get the guild from :attr:`Thread.guild`.
+
+ This requires :attr:`Intents.guilds` to be enabled.
+
+ .. versionadded:: 2.0
+
+ :param thread: The thread that got joined.
+ :type thread: :class:`Thread`
+
+.. function:: on_thread_remove(thread)
+
+ Called whenever a thread is removed. This is different from a thread being deleted.
+
+ Note that you can get the guild from :attr:`Thread.guild`.
+
+ This requires :attr:`Intents.guilds` to be enabled.
+
+ .. warning::
+
+ Due to technical limitations, this event might not be called
+ as soon as one expects. Since the library tracks thread membership
+ locally, the API only sends updated thread membership status upon being
+ synced by joining a thread.
+
+ .. versionadded:: 2.0
+
+ :param thread: The thread that got removed.
+ :type thread: :class:`Thread`
+
.. function:: on_thread_delete(thread)
- on_thread_create(thread)
- Called whenever a thread is deleted or created.
+ Called whenever a thread is deleted.
Note that you can get the guild from :attr:`Thread.guild`.
@@ -669,9 +701,21 @@ to handle it, which defaults to print a traceback and ignoring the exception.
.. versionadded:: 2.0
- :param thread: The thread that got created or deleted.
+ :param thread: The thread that got deleted.
:type thread: :class:`Thread`
+.. function:: on_thread_member_join(member)
+ on_thread_member_remove(member)
+
+ Called when a :class:`ThreadMember` leaves or joins a :class:`Thread`.
+
+ You can get the thread a member belongs in by accessing :attr:`ThreadMember.thread`.
+
+ This requires :attr:`Intents.members` to be enabled.
+
+ :param member: The member who joined or left.
+ :type member: :class:`ThreadMember`
+
.. function:: on_thread_update(before, after)
Called whenever a thread is updated.