aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/channel.py7
-rw-r--r--discord/guild.py11
2 files changed, 13 insertions, 5 deletions
diff --git a/discord/channel.py b/discord/channel.py
index ffb11ffc..f6f7d6b2 100644
--- a/discord/channel.py
+++ b/discord/channel.py
@@ -2038,3 +2038,10 @@ def _threaded_channel_factory(channel_type: int):
if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
return Thread, value
return cls, value
+
+
+def _threaded_guild_channel_factory(channel_type: int):
+ cls, value = _guild_channel_factory(channel_type)
+ if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread):
+ return Thread, value
+ return cls, value
diff --git a/discord/guild.py b/discord/guild.py
index c61ed7f0..ebdc8764 100644
--- a/discord/guild.py
+++ b/discord/guild.py
@@ -52,6 +52,7 @@ from .colour import Colour
from .errors import InvalidArgument, ClientException
from .channel import *
from .channel import _guild_channel_factory
+from .channel import _threaded_guild_channel_factory
from .enums import (
AuditLogAction,
VideoQualityMode,
@@ -1703,14 +1704,14 @@ class Guild(Hashable):
data: BanPayload = await self._state.http.get_ban(user.id, self.id)
return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason'])
- async def fetch_channel(self, channel_id: int, /) -> GuildChannel:
+ async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]:
"""|coro|
- Retrieves a :class:`.abc.GuildChannel` with the specified ID.
+ Retrieves a :class:`.abc.GuildChannel` or :class:`.Thread` with the specified ID.
.. note::
- This method is an API call. For general usage, consider :meth:`get_channel` instead.
+ This method is an API call. For general usage, consider :meth:`get_channel_or_thread` instead.
.. versionadded:: 2.0
@@ -1729,12 +1730,12 @@ class Guild(Hashable):
Returns
--------
- :class:`.abc.GuildChannel`
+ Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel from the ID.
"""
data = await self._state.http.get_channel(channel_id)
- factory, ch_type = _guild_channel_factory(data['type'])
+ factory, ch_type = _threaded_guild_channel_factory(data['type'])
if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data))