aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Nørgaard <[email protected]>2021-07-04 02:35:31 +0100
committerGitHub <[email protected]>2021-07-03 21:35:31 -0400
commitd1dc41ec2f6e0e6e6489151f33a39b7fccfc8f6f (patch)
tree1f3e91c728e87df04dcb32fcda5927af719004ed
parentFix ui.Button constructor default style to match the decorator (diff)
downloaddiscord.py-d1dc41ec2f6e0e6e6489151f33a39b7fccfc8f6f.tar.xz
discord.py-d1dc41ec2f6e0e6e6489151f33a39b7fccfc8f6f.zip
Fix Client.fetch_channel not returning Thread
-rw-r--r--discord/channel.py10
-rw-r--r--discord/client.py11
-rw-r--r--discord/guild.py4
-rw-r--r--discord/iterators.py2
-rw-r--r--discord/message.py2
-rw-r--r--discord/state.py4
-rw-r--r--discord/threads.py4
7 files changed, 22 insertions, 15 deletions
diff --git a/discord/channel.py b/discord/channel.py
index 5a213fe1..dbe8533f 100644
--- a/discord/channel.py
+++ b/discord/channel.py
@@ -691,7 +691,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
type=ChannelType.public_thread.value,
)
- return Thread(guild=self.guild, data=data)
+ return Thread(guild=self.guild, state=self._state, data=data)
def archived_threads(
self,
@@ -753,7 +753,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
data = await self._state.http.get_active_threads(self.id)
# TODO: thread members?
- return [Thread(guild=self.guild, data=d) for d in data.get('threads', [])]
+ return [Thread(guild=self.guild, state=self._state, data=d) for d in data.get('threads', [])]
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
@@ -1924,3 +1924,9 @@ def _channel_factory(channel_type: Union[ChannelType, int]):
return GroupChannel, value
else:
return cls, value
+
+def _threaded_channel_factory(channel_type: Union[ChannelType, int]):
+ cls, value = _channel_factory(channel_type)
+ if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
+ return Thread, value
+ return cls, value
diff --git a/discord/client.py b/discord/client.py
index 2b3c3e17..24ceb31b 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -39,7 +39,7 @@ from .template import Template
from .widget import Widget
from .guild import Guild
from .emoji import Emoji
-from .channel import _channel_factory
+from .channel import _threaded_channel_factory
from .enums import ChannelType
from .mentions import AllowedMentions
from .errors import *
@@ -58,6 +58,7 @@ from .iterators import GuildIterator
from .appinfo import AppInfo
from .ui.view import View
from .stage_instance import StageInstance
+from .threads import Thread
if TYPE_CHECKING:
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake
@@ -1371,10 +1372,10 @@ class Client:
data = await self.http.get_user(user_id)
return User(state=self._connection, data=data)
- async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel]:
+ async def fetch_channel(self, channel_id: int) -> Union[GuildChannel, PrivateChannel, Thread]:
"""|coro|
- Retrieves a :class:`.abc.GuildChannel` or :class:`.abc.PrivateChannel` with the specified ID.
+ Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
.. note::
@@ -1395,12 +1396,12 @@ class Client:
Returns
--------
- Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`]
+ Union[:class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, :class:`.Thread`]
The channel from the ID.
"""
data = await self.http.get_channel(channel_id)
- factory, ch_type = _channel_factory(data['type'])
+ factory, ch_type = _threaded_channel_factory(data['type'])
if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))
diff --git a/discord/guild.py b/discord/guild.py
index 54b86b29..1cf8eef0 100644
--- a/discord/guild.py
+++ b/discord/guild.py
@@ -287,7 +287,7 @@ class Guild(Hashable):
self._members[member.id] = member
def _store_thread(self, payload: ThreadPayload, /) -> Thread:
- thread = Thread(guild=self, data=payload)
+ thread = Thread(guild=self, state=self._state, data=payload)
self._threads[thread.id] = thread
return thread
@@ -466,7 +466,7 @@ class Guild(Hashable):
if 'threads' in data:
threads = data['threads']
for thread in threads:
- self._add_thread(Thread(guild=self, data=thread))
+ self._add_thread(Thread(guild=self, state=self._state, data=thread))
@property
def channels(self) -> List[GuildChannel]:
diff --git a/discord/iterators.py b/discord/iterators.py
index 2f272b70..f725d527 100644
--- a/discord/iterators.py
+++ b/discord/iterators.py
@@ -750,4 +750,4 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread
- return Thread(guild=self.guild, data=data)
+ return Thread(guild=self.guild, state=self.guild._state, data=data)
diff --git a/discord/message.py b/discord/message.py
index 825bda4d..b4604b38 100644
--- a/discord/message.py
+++ b/discord/message.py
@@ -1491,7 +1491,7 @@ class Message(Hashable):
auto_archive_duration=auto_archive_duration,
type=ChannelType.public_thread.value,
)
- return Thread(guild=self.guild, data=data) # type: ignore
+ return Thread(guild=self.guild, state=self._state, data=data) # type: ignore
async def reply(self, content: Optional[str] = None, **kwargs) -> Message:
"""|coro|
diff --git a/discord/state.py b/discord/state.py
index 5daa583e..f4a6a664 100644
--- a/discord/state.py
+++ b/discord/state.py
@@ -715,7 +715,7 @@ class ConnectionState:
log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id)
return
- thread = Thread(guild=guild, data=data)
+ thread = Thread(guild=guild, state=guild._state, data=data)
has_thread = guild.get_thread(thread.id)
guild._add_thread(thread)
if not has_thread:
@@ -735,7 +735,7 @@ class ConnectionState:
thread._update(data)
self.dispatch('thread_update', old, thread)
else:
- thread = Thread(guild=guild, data=data)
+ thread = Thread(guild=guild, state=guild._state, data=data)
guild._add_thread(thread)
self.dispatch('thread_join', thread)
diff --git a/discord/threads.py b/discord/threads.py
index 85a37018..24eda651 100644
--- a/discord/threads.py
+++ b/discord/threads.py
@@ -139,8 +139,8 @@ class Thread(Messageable, Hashable):
'archive_timestamp',
)
- def __init__(self, *, guild: Guild, data: ThreadPayload):
- self._state: ConnectionState = guild._state
+ def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
+ self._state: ConnectionState = state
self.guild = guild
self._members: Dict[int, ThreadMember] = {}
self._from_data(data)