aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2021-06-29 03:37:52 -0400
committerRapptz <[email protected]>2021-06-29 03:37:52 -0400
commit2beee8be1408fe044e6c10b665ec0baf8b8425c2 (patch)
tree3be8da6419fa86cf9eb14cdb218a08bd9f4d159e
parent[types] VoiceChannel and StageChannel bitrate/user_limit is not null (diff)
downloaddiscord.py-2beee8be1408fe044e6c10b665ec0baf8b8425c2.tar.xz
discord.py-2beee8be1408fe044e6c10b665ec0baf8b8425c2.zip
Type hint channel.py
-rw-r--r--discord/channel.py498
-rw-r--r--discord/guild.py2
2 files changed, 274 insertions, 226 deletions
diff --git a/discord/channel.py b/discord/channel.py
index 729f7f4d..f07ba638 100644
--- a/discord/channel.py
+++ b/discord/channel.py
@@ -26,7 +26,7 @@ from __future__ import annotations
import time
import asyncio
-from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union, overload
+from typing import Any, Callable, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload
import datetime
import discord.abc
@@ -34,8 +34,9 @@ from .permissions import PermissionOverwrite, Permissions
from .enums import ChannelType, StagePrivacyLevel, try_enum, VoiceRegion, VideoQualityMode
from .mixins import Hashable
from . import utils
+from .utils import MISSING
from .asset import Asset
-from .errors import ClientException, NoMoreItems, InvalidArgument
+from .errors import ClientException, InvalidArgument
from .stage_instance import StageInstance
from .threads import Thread
from .iterators import ArchivedThreadIterator
@@ -55,13 +56,27 @@ if TYPE_CHECKING:
from .role import Role
from .member import Member, VoiceState
from .abc import Snowflake, SnowflakeTime
- from .message import Message
+ from .message import Message, PartialMessage
from .webhook import Webhook
-
-async def _single_delete_strategy(messages):
+ from .state import ConnectionState
+ from .user import ClientUser, User, BaseUser
+ from .guild import Guild, GuildChannel as GuildChannelType
+ from .types.channel import (
+ TextChannel as TextChannelPayload,
+ VoiceChannel as VoiceChannelPayload,
+ StageChannel as StageChannelPayload,
+ DMChannel as DMChannelPayload,
+ CategoryChannel as CategoryChannelPayload,
+ StoreChannel as StoreChannelPayload,
+ GroupDMChannel as GroupChannelPayload,
+ )
+
+
+async def _single_delete_strategy(messages: Iterable[Message]):
for m in messages:
await m.delete()
+
class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""Represents a Discord guild text channel.
@@ -114,56 +129,67 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead.
"""
- __slots__ = ('name', 'id', 'guild', 'topic', '_state', 'nsfw',
- 'category_id', 'position', 'slowmode_delay', '_overwrites',
- '_type', 'last_message_id')
-
- def __init__(self, *, state, guild, data):
- self._state = state
- self.id = int(data['id'])
- self._type = data['type']
+ __slots__ = (
+ 'name',
+ 'id',
+ 'guild',
+ 'topic',
+ '_state',
+ 'nsfw',
+ 'category_id',
+ 'position',
+ 'slowmode_delay',
+ '_overwrites',
+ '_type',
+ 'last_message_id',
+ )
+
+ def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload):
+ self._state: ConnectionState = state
+ self.id: int = int(data['id'])
+ self._type: int = data['type']
self._update(guild, data)
- def __repr__(self):
+ def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
('position', self.position),
('nsfw', self.nsfw),
('news', self.is_news()),
- ('category_id', self.category_id)
+ ('category_id', self.category_id),
]
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
- def _update(self, guild, data):
- self.guild = guild
- self.name = data['name']
- self.category_id = utils._get_as_snowflake(data, 'parent_id')
- self.topic = data.get('topic')
- self.position = data['position']
- self.nsfw = data.get('nsfw', False)
+ def _update(self, guild: Guild, data: TextChannelPayload) -> None:
+ self.guild: Guild = guild
+ self.name: str = data['name']
+ self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
+ self.topic: Optional[str] = data.get('topic')
+ self.position: int = data['position']
+ self.nsfw: bool = data.get('nsfw', False)
# Does this need coercion into `int`? No idea yet.
- self.slowmode_delay = data.get('rate_limit_per_user', 0)
- self._type = data.get('type', self._type)
- self.last_message_id = utils._get_as_snowflake(data, 'last_message_id')
+ self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
+ self._type: int = data.get('type', self._type)
+ self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
self._fill_overwrites(data)
async def _get_channel(self):
return self
@property
- def type(self):
+ def type(self) -> ChannelType:
""":class:`ChannelType`: The channel's Discord type."""
return try_enum(ChannelType, self._type)
@property
- def _sorting_bucket(self):
+ def _sorting_bucket(self) -> int:
return ChannelType.text.value
@utils.copy_doc(discord.abc.GuildChannel.permissions_for)
- def permissions_for(self, member):
- base = super().permissions_for(member)
+ def permissions_for(self, obj: Union[Member, Role], /) -> Permissions:
+ base = super().permissions_for(obj)
# text channels do not have voice related permissions
denied = Permissions.voice()
@@ -171,28 +197,28 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
return base
@property
- def members(self):
+ def members(self) -> List[Member]:
"""List[:class:`Member`]: Returns all members that can see this channel."""
return [m for m in self.guild.members if self.permissions_for(m).read_messages]
@property
- def threads(self):
+ def threads(self) -> List[Thread]:
"""List[:class:`Thread`]: Returns all the threads that you can see.
.. versionadded:: 2.0
"""
return [thread for thread in self.guild.threads if thread.parent_id == self.id]
- def is_nsfw(self):
+ def is_nsfw(self) -> bool:
""":class:`bool`: Checks if the channel is NSFW."""
return self.nsfw
- def is_news(self):
+ def is_news(self) -> bool:
""":class:`bool`: Checks if the channel is a news channel."""
return self._type == ChannelType.news.value
@property
- def last_message(self):
+ def last_message(self) -> Optional[Message]:
"""Fetches the last message from this channel in cache.
The message might not be valid or point to an existing message.
@@ -289,14 +315,12 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
await self._edit(options, reason=reason)
@utils.copy_doc(discord.abc.GuildChannel.clone)
- async def clone(self, *, name: str = None, reason: str = None) -> TextChannel:
- return await self._clone_impl({
- 'topic': self.topic,
- 'nsfw': self.nsfw,
- 'rate_limit_per_user': self.slowmode_delay
- }, name=name, reason=reason)
-
- async def delete_messages(self, messages):
+ async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel:
+ return await self._clone_impl(
+ {'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason
+ )
+
+ async def delete_messages(self, messages: Iterable[Snowflake]) -> None:
"""|coro|
Deletes a list of messages. This is similar to :meth:`Message.delete`
@@ -332,24 +356,24 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
messages = list(messages)
if len(messages) == 0:
- return # do nothing
+ return # do nothing
if len(messages) == 1:
- message_id = messages[0].id
+ message_id: int = messages[0].id
await self._state.http.delete_message(self.id, message_id)
return
if len(messages) > 100:
raise ClientException('Can only bulk delete messages up to 100 messages')
- message_ids = [m.id for m in messages]
+ message_ids: List[int] = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids)
async def purge(
self,
*,
limit: int = 100,
- check: Callable[[Message], bool] = None,
+ check: Callable[[Message], bool] = MISSING,
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
around: Optional[SnowflakeTime] = None,
@@ -412,54 +436,52 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
The list of messages that were deleted.
"""
- if check is None:
+ if check is MISSING:
check = lambda m: True
iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around)
- ret = []
+ ret: List[Message] = []
count = 0
minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22
strategy = self.delete_messages if bulk else _single_delete_strategy
- while True:
- try:
- msg = await iterator.next()
- except NoMoreItems:
- # no more messages to poll
- if count >= 2:
- # more than 2 messages -> bulk delete
+ async for message in iterator:
+ if count == 100:
+ to_delete = ret[-100:]
+ await strategy(to_delete)
+ count = 0
+ await asyncio.sleep(1)
+
+ if not check(message):
+ continue
+
+ if message.id < minimum_time:
+ # older than 14 days old
+ if count == 1:
+ await ret[-1].delete()
+ elif count >= 2:
to_delete = ret[-count:]
await strategy(to_delete)
- elif count == 1:
- # delete a single message
- await ret[-1].delete()
- return ret
- else:
- if count == 100:
- # we've reached a full 'queue'
- to_delete = ret[-100:]
- await strategy(to_delete)
- count = 0
- await asyncio.sleep(1)
+ count = 0
+ strategy = _single_delete_strategy
- if check(msg):
- if msg.id < minimum_time:
- # older than 14 days old
- if count == 1:
- await ret[-1].delete()
- elif count >= 2:
- to_delete = ret[-count:]
- await strategy(to_delete)
+ count += 1
+ ret.append(message)
- count = 0
- strategy = _single_delete_strategy
+ # SOme messages remaining to poll
+ if count >= 2:
+ # more than 2 messages -> bulk delete
+ to_delete = ret[-count:]
+ await strategy(to_delete)
+ elif count == 1:
+ # delete a single message
+ await ret[-1].delete()
- count += 1
- ret.append(msg)
+ return ret
- async def webhooks(self):
+ async def webhooks(self) -> List[Webhook]:
"""|coro|
Gets the list of webhooks from this channel.
@@ -478,10 +500,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
from .webhook import Webhook
+
data = await self._state.http.channel_webhooks(self.id)
return [Webhook.from_state(d, state=self._state) for d in data]
- async def create_webhook(self, *, name: str, avatar: bytes = None, reason: str = None) -> Webhook:
+ async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook:
"""|coro|
Creates a webhook for this channel.
@@ -515,8 +538,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
from .webhook import Webhook
+
if avatar is not None:
- avatar = utils._bytes_to_base64_data(avatar)
+ avatar = utils._bytes_to_base64_data(avatar) # type: ignore
data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason)
return Webhook.from_state(data, state=self._state)
@@ -563,10 +587,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
raise InvalidArgument(f'Expected TextChannel received {destination.__class__.__name__}')
from .webhook import Webhook
+
data = await self._state.http.follow_webhook(self.id, webhook_channel_id=destination.id, reason=reason)
return Webhook._as_follower(data, channel=destination, user=self._state.user)
- def get_partial_message(self, message_id):
+ def get_partial_message(self, message_id: int, /) -> PartialMessage:
"""Creates a :class:`PartialMessage` from the message ID.
This is useful if you want to work with a message and only have its ID without
@@ -586,9 +611,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""
from .message import PartialMessage
+
return PartialMessage(channel=self, id=message_id)
- def get_thread(self, thread_id: int) -> Optional[Thread]:
+ def get_thread(self, thread_id: int, /) -> Optional[Thread]:
"""Returns a thread with the given ID.
.. versionadded:: 2.0
@@ -724,37 +750,47 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
# TODO: thread members?
return [Thread(guild=self.guild, data=d) for d in data.get('threads', [])]
-class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
- __slots__ = ('name', 'id', 'guild', 'bitrate', 'user_limit',
- '_state', 'position', '_overwrites', 'category_id',
- 'rtc_region', 'video_quality_mode')
- def __init__(self, *, state, guild, data):
- self._state = state
- self.id = int(data['id'])
+class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
+ __slots__ = (
+ 'name',
+ 'id',
+ 'guild',
+ 'bitrate',
+ 'user_limit',
+ '_state',
+ 'position',
+ '_overwrites',
+ 'category_id',
+ 'rtc_region',
+ 'video_quality_mode',
+ )
+
+ def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]):
+ self._state: ConnectionState = state
+ self.id: int = int(data['id'])
self._update(guild, data)
- def _get_voice_client_key(self):
+ def _get_voice_client_key(self) -> Tuple[int, str]:
return self.guild.id, 'guild_id'
- def _get_voice_state_pair(self):
+ def _get_voice_state_pair(self) -> Tuple[int, int]:
return self.guild.id, self.id
- def _update(self, guild, data):
+ def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None:
self.guild = guild
- self.name = data['name']
- self.rtc_region = data.get('rtc_region')
- if self.rtc_region:
- self.rtc_region = try_enum(VoiceRegion, self.rtc_region)
- self.video_quality_mode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1))
- self.category_id = utils._get_as_snowflake(data, 'parent_id')
- self.position = data['position']
- self.bitrate = data.get('bitrate')
- self.user_limit = data.get('user_limit')
+ self.name: str = data['name']
+ rtc = data.get('rtc_region')
+ self.rtc_region: Optional[VoiceRegion] = try_enum(VoiceRegion, rtc) if rtc is not None else None
+ self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1))
+ self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
+ self.position: int = data['position']
+ self.bitrate: int = data.get('bitrate')
+ self.user_limit: int = data.get('user_limit')
self._fill_overwrites(data)
@property
- def _sorting_bucket(self):
+ def _sorting_bucket(self) -> int:
return ChannelType.voice.value
@property
@@ -787,8 +823,8 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha
return {key: value for key, value in self.guild._voice_states.items() if value.channel.id == self.id}
@utils.copy_doc(discord.abc.GuildChannel.permissions_for)
- def permissions_for(self, member: Union[Role, Member], /) -> Permissions:
- base = super().permissions_for(member)
+ def permissions_for(self, obj: Union[Member, Role], /) -> Permissions:
+ base = super().permissions_for(obj)
# voice channels cannot be edited by people who can't connect to them
# It also implicitly denies all other voice perms
@@ -798,6 +834,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha
base.value &= ~denied.value
return base
+
class VoiceChannel(VocalGuildChannel):
"""Represents a Discord guild voice channel.
@@ -849,7 +886,7 @@ class VoiceChannel(VocalGuildChannel):
__slots__ = ()
- def __repr__(self):
+ def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
@@ -858,28 +895,24 @@ class VoiceChannel(VocalGuildChannel):
('bitrate', self.bitrate),
('video_quality_mode', self.video_quality_mode),
('user_limit', self.user_limit),
- ('category_id', self.category_id)
+ ('category_id', self.category_id),
]
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
@property
- def type(self):
+ def type(self) -> ChannelType:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.voice
@utils.copy_doc(discord.abc.GuildChannel.clone)
- async def clone(self, *, name: str = None, reason: str = None) -> VoiceChannel:
- return await self._clone_impl({
- 'bitrate': self.bitrate,
- 'user_limit': self.user_limit
- }, name=name, reason=reason)
+ async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel:
+ return await self._clone_impl({'bitrate': self.bitrate, 'user_limit': self.user_limit}, name=name, reason=reason)
@overload
async def edit(
self,
*,
- reason: Optional[str] = ...,
name: str = ...,
bitrate: int = ...,
user_limit: int = ...,
@@ -889,6 +922,7 @@ class VoiceChannel(VocalGuildChannel):
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
rtc_region: Optional[VoiceRegion] = ...,
video_quality_mode: VideoQualityMode = ...,
+ reason: Optional[str] = ...,
) -> None:
...
@@ -950,6 +984,7 @@ class VoiceChannel(VocalGuildChannel):
await self._edit(options, reason=reason)
+
class StageChannel(VocalGuildChannel):
"""Represents a Discord guild stage channel.
@@ -1000,9 +1035,10 @@ class StageChannel(VocalGuildChannel):
.. versionadded:: 2.0
"""
+
__slots__ = ('topic',)
- def __repr__(self):
+ def __repr__(self) -> str:
attrs = [
('id', self.id),
('name', self.name),
@@ -1012,12 +1048,12 @@ class StageChannel(VocalGuildChannel):
('bitrate', self.bitrate),
('video_quality_mode', self.video_quality_mode),
('user_limit', self.user_limit),
- ('category_id', self.category_id)
+ ('category_id', self.category_id),
]
joined = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>'
- def _update(self, guild, data):
+ def _update(self, guild: Guild, data: StageChannelPayload) -> None:
super()._update(guild, data)
self.topic = data.get('topic')
@@ -1032,7 +1068,9 @@ class StageChannel(VocalGuildChannel):
.. versionadded:: 2.0
"""
- return [member for member in self.members if not member.voice.suppress and member.voice.requested_to_speak_at is None]
+ return [
+ member for member in self.members if not member.voice.suppress and member.voice.requested_to_speak_at is None
+ ]
@property
def listeners(self) -> List[Member]:
@@ -1052,12 +1090,12 @@ class StageChannel(VocalGuildChannel):
return [member for member in self.members if self.permissions_for(member) >= required_permissions]
@property
- def type(self):
+ def type(self) -> ChannelType:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.stage_voice
@utils.copy_doc(discord.abc.GuildChannel.clone)
- async def clone(self, *, name: str = None, reason: Optional[str] = None) -> StageChannel:
+ async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StageChannel:
return await self._clone_impl({}, name=name, reason=reason)
@property
@@ -1068,7 +1106,7 @@ class StageChannel(VocalGuildChannel):
"""
return utils.get(self.guild.stage_instances, channel_id=self.id)
- async def create_instance(self, *, topic: str, privacy_level: StagePrivacyLevel = utils.MISSING) -> StageInstance:
+ async def create_instance(self, *, topic: str, privacy_level: StagePrivacyLevel = MISSING) -> StageInstance:
"""|coro|
Create a stage instance.
@@ -1100,12 +1138,9 @@ class StageChannel(VocalGuildChannel):
The newly created stage instance.
"""
- payload = {
- 'channel_id': self.id,
- 'topic': topic
- }
+ payload: Dict[str, Any] = {'channel_id': self.id, 'topic': topic}
- if privacy_level is not utils.MISSING:
+ if privacy_level is not MISSING:
if not isinstance(privacy_level, StagePrivacyLevel):
raise InvalidArgument('privacy_level field must be of type PrivacyLevel')
@@ -1140,7 +1175,6 @@ class StageChannel(VocalGuildChannel):
async def edit(
self,
*,
- reason: Optional[str] = ...,
name: str = ...,
topic: Optional[str] = ...,
position: int = ...,
@@ -1149,6 +1183,7 @@ class StageChannel(VocalGuildChannel):
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
rtc_region: Optional[VoiceRegion] = ...,
video_quality_mode: VideoQualityMode = ...,
+ reason: Optional[str] = ...,
) -> None:
...
@@ -1203,6 +1238,8 @@ class StageChannel(VocalGuildChannel):
"""
await self._edit(options, reason=reason)
+
+
class CategoryChannel(discord.abc.GuildChannel, Hashable):
"""Represents a Discord channel category.
@@ -1247,50 +1284,48 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
__slots__ = ('name', 'id', 'guild', 'nsfw', '_state', 'position', '_overwrites', 'category_id')
- def __init__(self, *, state, guild, data):
- self._state = state
- self.id = int(data['id'])
+ def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload):
+ self._state: ConnectionState = state
+ self.id: int = int(data['id'])
self._update(guild, data)
- def __repr__(self):
+ def __repr__(self) -> str:
return f'<CategoryChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>'
- def _update(self, guild, data):
- self.guild = guild
- self.name = data['name']
- self.category_id = utils._get_as_snowflake(data, 'parent_id')
- self.nsfw = data.get('nsfw', False)
- self.position = data['position']
+ def _update(self, guild: Guild, data: CategoryChannelPayload) -> None:
+ self.guild: Guild = guild
+ self.name: str = data['name']
+ self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
+ self.nsfw: bool = data.get('nsfw', False)
+ self.position: int = data['position']
self._fill_overwrites(data)
@property
- def _sorting_bucket(self):
+ def _sorting_bucket(self) -> int:
return ChannelType.category.value
@property
- def type(self):
+ def type(self) -> ChannelType:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.category
- def is_nsfw(self):
+ def is_nsfw(self) -> bool:
""":class:`bool`: Checks if the category is NSFW."""
return self.nsfw
@utils.copy_doc(discord.abc.GuildChannel.clone)
- async def clone(self, *, name: str = None, reason: Optional[str] = None) -> CategoryChannel:
- return await self._clone_impl({
- 'nsfw': self.nsfw
- }, name=name, reason=reason)
+ async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel:
+ return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason)
@overload
async def edit(
self,
*,
- reason: Optional[str] = ...,
name: str = ...,
position: int = ...,
nsfw: bool = ...,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
+ reason: Optional[str] = ...,
) -> None:
...
@@ -1341,11 +1376,12 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
await super().move(**kwargs)
@property
- def channels(self):
+ def channels(self) -> List[GuildChannelType]:
"""List[:class:`abc.GuildChannel`]: Returns the channels that are under this category.
These are sorted by the official Discord UI, which places voice channels below the text channels.
"""
+
def comparator(channel):
return (not isinstance(channel, TextChannel), channel.position)
@@ -1354,36 +1390,30 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
return ret
@property
- def text_channels(self):
+ def text_channels(self) -> List[TextChannel]:
"""List[:class:`TextChannel`]: Returns the text channels that are under this category."""
- ret = [c for c in self.guild.channels
- if c.category_id == self.id
- and isinstance(c, TextChannel)]
+ ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, TextChannel)]
ret.sort(key=lambda c: (c.position, c.id))
return ret
@property
- def voice_channels(self):
+ def voice_channels(self) -> List[VoiceChannel]:
"""List[:class:`VoiceChannel`]: Returns the voice channels that are under this category."""
- ret = [c for c in self.guild.channels
- if c.category_id == self.id
- and isinstance(c, VoiceChannel)]
+ ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, VoiceChannel)]
ret.sort(key=lambda c: (c.position, c.id))
return ret
@property
- def stage_channels(self):
+ def stage_channels(self) -> List[StageChannel]:
"""List[:class:`StageChannel`]: Returns the stage channels that are under this category.
.. versionadded:: 1.7
"""
- ret = [c for c in self.guild.channels
- if c.category_id == self.id
- and isinstance(c, StageChannel)]
+ ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, StageChannel)]
ret.sort(key=lambda c: (c.position, c.id))
return ret
- async def create_text_channel(self, name, **options):
+ async def create_text_channel(self, name: str, **options: Any) -> TextChannel:
"""|coro|
A shortcut method to :meth:`Guild.create_text_channel` to create a :class:`TextChannel` in the category.
@@ -1395,7 +1425,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
"""
return await self.guild.create_text_channel(name, category=self, **options)
- async def create_voice_channel(self, name, **options):
+ async def create_voice_channel(self, name: str, **options: Any) -> VoiceChannel:
"""|coro|
A shortcut method to :meth:`Guild.create_voice_channel` to create a :class:`VoiceChannel` in the category.
@@ -1407,7 +1437,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
"""
return await self.guild.create_voice_channel(name, category=self, **options)
- async def create_stage_channel(self, name, **options):
+ async def create_stage_channel(self, name: str, **options: Any) -> StageChannel:
"""|coro|
A shortcut method to :meth:`Guild.create_stage_channel` to create a :class:`StageChannel` in the category.
@@ -1421,6 +1451,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
"""
return await self.guild.create_stage_channel(name, category=self, **options)
+
class StoreChannel(discord.abc.GuildChannel, Hashable):
"""Represents a Discord guild store channel.
@@ -1462,52 +1493,59 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead.
"""
- __slots__ = ('name', 'id', 'guild', '_state', 'nsfw',
- 'category_id', 'position', '_overwrites',)
- def __init__(self, *, state, guild, data):
- self._state = state
- self.id = int(data['id'])
+ __slots__ = (
+ 'name',
+ 'id',
+ 'guild',
+ '_state',
+ 'nsfw',
+ 'category_id',
+ 'position',
+ '_overwrites',
+ )
+
+ def __init__(self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload):
+ self._state: ConnectionState = state
+ self.id: int = int(data['id'])
self._update(guild, data)
- def __repr__(self):
+ def __repr__(self) -> str:
return f'<StoreChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>'
- def _update(self, guild, data):
- self.guild = guild
- self.name = data['name']
- self.category_id = utils._get_as_snowflake(data, 'parent_id')
- self.position = data['position']
- self.nsfw = data.get('nsfw', False)
+ def _update(self, guild: Guild, data: StoreChannelPayload) -> None:
+ self.guild: Guild = guild
+ self.name: str = data['name']
+ self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id')
+ self.position: int = data['position']
+ self.nsfw: bool = data.get('nsfw', False)
self._fill_overwrites(data)
@property
- def _sorting_bucket(self):
+ def _sorting_bucket(self) -> int:
return ChannelType.text.value
@property
- def type(self):
+ def type(self) -> ChannelType:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.store
@utils.copy_doc(discord.abc.GuildChannel.permissions_for)
- def permissions_for(self, member):
- base = super().permissions_for(member)
+ def permissions_for(self, obj: Union[Member, Role], /) -> Permissions:
+ base = super().permissions_for(obj)
# store channels do not have voice related permissions
denied = Permissions.voice()
base.value &= ~denied.value
return base
- def is_nsfw(self):
+ def is_nsfw(self) -> bool:
""":class:`bool`: Checks if the channel is NSFW."""
return self.nsfw
@utils.copy_doc(discord.abc.GuildChannel.clone)
- async def clone(self, *, name: str = None, reason: Optional[str] = None) -> StoreChannel:
- return await self._clone_impl({
- 'nsfw': self.nsfw
- }, name=name, reason=reason)
+ async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StoreChannel:
+ return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason)
@overload
async def edit(
@@ -1519,7 +1557,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
sync_permissions: bool = ...,
category: Optional[CategoryChannel],
reason: Optional[str],
- overwrites: Dict[Union[Role, Member], PermissionOverwrite]
+ overwrites: Dict[Union[Role, Member], PermissionOverwrite],
) -> None:
...
@@ -1569,6 +1607,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
"""
await self._edit(options, reason=reason)
+
+DMC = TypeVar('DMC', bound='DMChannel')
+
+
class DMChannel(discord.abc.Messageable, Hashable):
"""Represents a Discord direct message channel.
@@ -1604,43 +1646,43 @@ class DMChannel(discord.abc.Messageable, Hashable):
__slots__ = ('id', 'recipient', 'me', '_state')
- def __init__(self, *, me, state, data):
- self._state = state
- self.recipient = state.store_user(data['recipients'][0])
- self.me = me
- self.id = int(data['id'])
+ def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload):
+ self._state: ConnectionState = state
+ self.recipient: Optional[User] = state.store_user(data['recipients'][0])
+ self.me: ClientUser = me
+ self.id: int = int(data['id'])
async def _get_channel(self):
return self
- def __str__(self):
+ def __str__(self) -> str:
if self.recipient:
return f'Direct Message with {self.recipient}'
return 'Direct Message with Unknown User'
- def __repr__(self):
+ def __repr__(self) -> str:
return f'<DMChannel id={self.id} recipient={self.recipient!r}>'
@classmethod
- def _from_message(cls, state, channel_id):
- self = cls.__new__(cls)
+ def _from_message(cls: Type[DMC], state: ConnectionState, channel_id: int) -> DMC:
+ self: DMC = cls.__new__(cls)
self._state = state
self.id = channel_id
self.recipient = None
- self.me = state.user
+ self.me = state.user # type: ignore
return self
@property
- def type(self):
+ def type(self) -> ChannelType:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.private
@property
- def created_at(self):
+ def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: Returns the direct message channel's creation time in UTC."""
return utils.snowflake_time(self.id)
- def permissions_for(self, user=None):
+ def permissions_for(self, obj: Any = None, /) -> Permissions:
"""Handles permission resolution for a :class:`User`.
This function is there for compatibility with other channel types.
@@ -1654,9 +1696,9 @@ class DMChannel(discord.abc.Messageable, Hashable):
Parameters
-----------
- user: :class:`User`
+ obj: :class:`User`
The user to check permissions for. This parameter is ignored
- but kept for compatibility.
+ but kept for compatibility with other ``permissions_for`` methods.
Returns
--------
@@ -1670,7 +1712,7 @@ class DMChannel(discord.abc.Messageable, Hashable):
base.manage_messages = False
return base
- def get_partial_message(self, message_id):
+ def get_partial_message(self, message_id: int, /) -> PartialMessage:
"""Creates a :class:`PartialMessage` from the message ID.
This is useful if you want to work with a message and only have its ID without
@@ -1690,8 +1732,10 @@ class DMChannel(discord.abc.Messageable, Hashable):
"""
from .message import PartialMessage
+
return PartialMessage(channel=self, id=message_id)
+
class GroupChannel(discord.abc.Messageable, Hashable):
"""Represents a Discord group channel.
@@ -1721,39 +1765,40 @@ class GroupChannel(discord.abc.Messageable, Hashable):
The user presenting yourself.
id: :class:`int`
The group channel ID.
- owner: :class:`User`
+ owner: Optional[:class:`User`]
The user that owns the group channel.
+ owner_id: :class:`int`
+ The owner ID that owns the group channel.
+
+ .. versionadded:: 2.0
name: Optional[:class:`str`]
The group channel's name if provided.
"""
- __slots__ = ('id', 'recipients', 'owner', '_icon', 'name', 'me', '_state')
+ __slots__ = ('id', 'recipients', 'owner_id', 'owner', '_icon', 'name', 'me', '_state')
- def __init__(self, *, me, state, data):
- self._state = state
- self.id = int(data['id'])
- self.me = me
+ def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload):
+ self._state: ConnectionState = state
+ self.id: int = int(data['id'])
+ self.me: ClientUser = me
self._update_group(data)
- def _update_group(self, data):
- owner_id = utils._get_as_snowflake(data, 'owner_id')
- self._icon = data.get('icon')
- self.name = data.get('name')
-
- try:
- self.recipients = [self._state.store_user(u) for u in data['recipients']]
- except KeyError:
- pass
+ def _update_group(self, data: GroupChannelPayload) -> None:
+ self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_id')
+ self._icon: Optional[str] = data.get('icon')
+ self.name: Optional[str] = data.get('name')
+ self.recipients: List[User] = [self._state.store_user(u) for u in data.get('recipients', [])]
- if owner_id == self.me.id:
+ self.owner: Optional[BaseUser]
+ if self.owner_id == self.me.id:
self.owner = self.me
else:
- self.owner = utils.find(lambda u: u.id == owner_id, self.recipients)
+ self.owner = utils.find(lambda u: u.id == self.owner_id, self.recipients)
async def _get_channel(self):
return self
- def __str__(self):
+ def __str__(self) -> str:
if self.name:
return self.name
@@ -1762,27 +1807,27 @@ class GroupChannel(discord.abc.Messageable, Hashable):
return ', '.join(map(lambda x: x.name, self.recipients))
- def __repr__(self):
+ def __repr__(self) -> str:
return f'<GroupChannel id={self.id} name={self.name!r}>'
@property
- def type(self):
+ def type(self) -> ChannelType:
""":class:`ChannelType`: The channel's Discord type."""
return ChannelType.group
@property
- def icon(self):
+ def icon(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the channel's icon asset if available."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='channel')
@property
- def created_at(self):
+ def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: Returns the channel's creation time in UTC."""
return utils.snowflake_time(self.id)
- def permissions_for(self, user):
+ def permissions_for(self, obj: Snowflake, /) -> Permissions:
"""Handles permission resolution for a :class:`User`.
This function is there for compatibility with other channel types.
@@ -1798,7 +1843,7 @@ class GroupChannel(discord.abc.Messageable, Hashable):
Parameters
-----------
- user: :class:`User`
+ obj: :class:`~discord.abc.Snowflake`
The user to check permissions for.
Returns
@@ -1813,12 +1858,12 @@ class GroupChannel(discord.abc.Messageable, Hashable):
base.manage_messages = False
base.mention_everyone = True
- if user.id == self.owner.id:
+ if obj.id == self.owner_id:
base.kick_members = True
return base
- async def leave(self):
+ async def leave(self) -> None:
"""|coro|
Leave the group.
@@ -1833,11 +1878,13 @@ class GroupChannel(discord.abc.Messageable, Hashable):
await self._state.http.leave_group(self.id)
+
def _coerce_channel_type(value: Union[ChannelType, int]) -> ChannelType:
if isinstance(value, ChannelType):
return value
return try_enum(ChannelType, value)
+
def _guild_channel_factory(channel_type: Union[ChannelType, int]):
value = _coerce_channel_type(channel_type)
if value is ChannelType.text:
@@ -1855,6 +1902,7 @@ def _guild_channel_factory(channel_type: Union[ChannelType, int]):
else:
return None, value
+
def _channel_factory(channel_type: Union[ChannelType, int]):
cls, value = _guild_channel_factory(channel_type)
if value is ChannelType.private:
diff --git a/discord/guild.py b/discord/guild.py
index 024aa5d8..b43b12df 100644
--- a/discord/guild.py
+++ b/discord/guild.py
@@ -461,7 +461,7 @@ class Guild(Hashable):
for c in channels:
factory, ch_type = _guild_channel_factory(c['type'])
if factory:
- self._add_channel(factory(guild=self, data=c, state=self._state))
+ self._add_channel(factory(guild=self, data=c, state=self._state)) # type: ignore
if 'threads' in data:
threads = data['threads']