aboutsummaryrefslogtreecommitdiff
path: root/discord/abc.py
diff options
context:
space:
mode:
Diffstat (limited to 'discord/abc.py')
-rw-r--r--discord/abc.py247
1 files changed, 171 insertions, 76 deletions
diff --git a/discord/abc.py b/discord/abc.py
index b94f9a71..df43dae5 100644
--- a/discord/abc.py
+++ b/discord/abc.py
@@ -26,7 +26,21 @@ from __future__ import annotations
import copy
import asyncio
-from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Protocol, Type, TypeVar, Union, overload, runtime_checkable
+from typing import (
+ Any,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ TYPE_CHECKING,
+ Protocol,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+ runtime_checkable,
+)
from .iterators import HistoryIterator
from .context_managers import Typing
@@ -62,16 +76,24 @@ if TYPE_CHECKING:
from .channel import CategoryChannel
from .embeds import Embed
from .message import Message, MessageReference
+ from .channel import TextChannel, DMChannel, GroupChannel
+ from .threads import Thread
from .enums import InviteTarget
from .ui.view import View
+ from .types.channel import (
+ PermissionOverwrite as PermissionOverwritePayload,
+ GuildChannel as GuildChannelPayload,
+ OverwriteType,
+ )
+ MessageableChannel = Union[TextChannel, Thread, DMChannel, GroupChannel]
SnowflakeTime = Union["Snowflake", datetime]
MISSING = utils.MISSING
class _Undefined:
- def __repr__(self):
+ def __repr__(self) -> str:
return 'see-below'
@@ -102,6 +124,7 @@ class Snowflake(Protocol):
""":class:`datetime.datetime`: Returns the model's creation time as an aware datetime in UTC."""
raise NotImplementedError
+
@runtime_checkable
class User(Snowflake, Protocol):
"""An ABC that details the common operations on a Discord user.
@@ -172,13 +195,13 @@ class _Overwrites:
ROLE = 0
MEMBER = 1
- def __init__(self, **kwargs):
- self.id = kwargs.pop('id')
- self.allow = int(kwargs.pop('allow', 0))
- self.deny = int(kwargs.pop('deny', 0))
- self.type = kwargs.pop('type')
+ def __init__(self, data: PermissionOverwritePayload):
+ self.id: int = int(data.pop('id'))
+ self.allow: int = int(data.pop('allow', 0))
+ self.deny: int = int(data.pop('deny', 0))
+ self.type: OverwriteType = data.pop('type')
- def _asdict(self):
+ def _asdict(self) -> PermissionOverwritePayload:
return {
'id': self.id,
'allow': str(self.allow),
@@ -208,11 +231,6 @@ class GuildChannel:
This ABC must also implement :class:`~discord.abc.Snowflake`.
- Note
- ----
- This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
- checks.
-
Attributes
-----------
name: :class:`str`
@@ -230,7 +248,10 @@ class GuildChannel:
name: str
guild: Guild
type: ChannelType
+ position: int
+ category_id: Optional[int]
_state: ConnectionState
+ _overwrites: List[_Overwrites]
if TYPE_CHECKING:
@@ -254,13 +275,13 @@ class GuildChannel:
lock_permissions: bool = False,
*,
reason: Optional[str],
- ):
+ ) -> None:
if position < 0:
raise InvalidArgument('Channel position cannot be less than 0.')
http = self._state.http
bucket = self._sorting_bucket
- channels = [c for c in self.guild.channels if c._sorting_bucket == bucket]
+ channels: List[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket]
channels.sort(key=lambda c: c.position)
@@ -277,7 +298,7 @@ class GuildChannel:
payload = []
for index, c in enumerate(channels):
- d = {'id': c.id, 'position': index}
+ d: Dict[str, Any] = {'id': c.id, 'position': index}
if parent_id is not _undefined and c.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d)
@@ -287,7 +308,7 @@ class GuildChannel:
if parent_id is not _undefined:
self.category_id = int(parent_id) if parent_id else None
- async def _edit(self, options, reason):
+ async def _edit(self, options: Dict[str, Any], reason: Optional[str]):
try:
parent = options.pop('category')
except KeyError:
@@ -322,13 +343,15 @@ class GuildChannel:
if parent_id is not _undefined:
if lock_permissions:
category = self.guild.get_channel(parent_id)
- options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
+ if category:
+ options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
options['parent_id'] = parent_id
elif lock_permissions and self.category_id is not None:
# if we're syncing permissions on a pre-existing channel category without changing it
# we need to update the permissions to point to the pre-existing category
category = self.guild.get_channel(self.category_id)
- options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
+ if category:
+ options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
else:
await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason)
@@ -367,19 +390,19 @@ class GuildChannel:
data = await self._state.http.edit_channel(self.id, reason=reason, **options)
self._update(self.guild, data)
- def _fill_overwrites(self, data):
+ def _fill_overwrites(self, data: GuildChannelPayload) -> None:
self._overwrites = []
everyone_index = 0
everyone_id = self.guild.id
for index, overridden in enumerate(data.get('permission_overwrites', [])):
- overridden_id = int(overridden.pop('id'))
- self._overwrites.append(_Overwrites(id=overridden_id, **overridden))
+ overwrite = _Overwrites(overridden)
+ self._overwrites.append(overwrite)
if overridden['type'] == _Overwrites.MEMBER:
continue
- if overridden_id == everyone_id:
+ if overwrite.id == everyone_id:
# the @everyone role is not guaranteed to be the first one
# in the list of permission overwrites, however the permission
# resolution code kind of requires that it is the first one in
@@ -488,7 +511,7 @@ class GuildChannel:
If there is no category then this is ``None``.
"""
- return self.guild.get_channel(self.category_id)
+ return self.guild.get_channel(self.category_id) # type: ignore
@property
def permissions_synced(self) -> bool:
@@ -499,6 +522,9 @@ class GuildChannel:
.. versionadded:: 1.3
"""
+ if self.category_id is None:
+ return False
+
category = self.guild.get_channel(self.category_id)
return bool(category and category.overwrites == self.overwrites)
@@ -679,14 +705,7 @@ class GuildChannel:
) -> None:
...
- async def set_permissions(
- self,
- target,
- *,
- overwrite=_undefined,
- reason=None,
- **permissions
- ):
+ async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions):
r"""|coro|
Sets the channel specific permission overwrites for a target in the
@@ -801,7 +820,7 @@ class GuildChannel:
obj = cls(state=self._state, guild=self.guild, data=data)
# temporarily add it to the cache
- self.guild._channels[obj.id] = obj
+ self.guild._channels[obj.id] = obj # type: ignore
return obj
async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH:
@@ -956,6 +975,7 @@ class GuildChannel:
bucket = self._sorting_bucket
parent_id = kwargs.get('category', MISSING)
# fmt: off
+ channels: List[GuildChannel]
if parent_id not in (MISSING, None):
parent_id = parent_id.id
channels = [
@@ -1017,7 +1037,7 @@ class GuildChannel:
unique: bool = True,
target_type: Optional[InviteTarget] = None,
target_user: Optional[User] = None,
- target_application_id: Optional[int] = None
+ target_application_id: Optional[int] = None,
) -> Invite:
"""|coro|
@@ -1045,9 +1065,9 @@ class GuildChannel:
The reason for creating this invite. Shows up on the audit log.
target_type: Optional[:class:`.InviteTarget`]
The type of target for the voice channel invite, if any.
-
+
.. versionadded:: 2.0
-
+
target_user: Optional[:class:`User`]
The user whose stream to display for this invite, required if `target_type` is `TargetType.stream`. The user must be streaming in the channel.
@@ -1081,7 +1101,7 @@ class GuildChannel:
unique=unique,
target_type=target_type.value if target_type else None,
target_user_id=target_user.id if target_user else None,
- target_application_id=target_application_id
+ target_application_id=target_application_id,
)
return Invite.from_incomplete(data=data, state=self._state)
@@ -1111,7 +1131,7 @@ class GuildChannel:
return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data]
-class Messageable(Protocol):
+class Messageable:
"""An ABC that details the common operations on a model that can send messages.
The following implement this ABC:
@@ -1122,28 +1142,57 @@ class Messageable(Protocol):
- :class:`~discord.User`
- :class:`~discord.Member`
- :class:`~discord.ext.commands.Context`
-
-
- Note
- ----
- This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
- checks.
"""
__slots__ = ()
+ _state: ConnectionState
- async def _get_channel(self):
+ async def _get_channel(self) -> MessageableChannel:
raise NotImplementedError
@overload
async def send(
self,
- content: Optional[str] =...,
+ content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
file: File = ...,
- delete_after: int = ...,
+ delete_after: float = ...,
+ nonce: Union[str, int] = ...,
+ allowed_mentions: AllowedMentions = ...,
+ reference: Union[Message, MessageReference] = ...,
+ mention_author: bool = ...,
+ view: View = ...,
+ ) -> Message:
+ ...
+
+ @overload
+ async def send(
+ self,
+ content: Optional[str] = ...,
+ *,
+ tts: bool = ...,
+ embed: Embed = ...,
+ files: List[File] = ...,
+ delete_after: float = ...,
+ nonce: Union[str, int] = ...,
+ allowed_mentions: AllowedMentions = ...,
+ reference: Union[Message, MessageReference] = ...,
+ mention_author: bool = ...,
+ view: View = ...,
+ ) -> Message:
+ ...
+
+ @overload
+ async def send(
+ self,
+ content: Optional[str] = ...,
+ *,
+ tts: bool = ...,
+ embeds: List[Embed] = ...,
+ file: File = ...,
+ delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference] = ...,
@@ -1160,7 +1209,7 @@ class Messageable(Protocol):
tts: bool = ...,
embeds: List[Embed] = ...,
files: List[File] = ...,
- delete_after: int = ...,
+ delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference] = ...,
@@ -1169,10 +1218,22 @@ class Messageable(Protocol):
) -> Message:
...
- async def send(self, content=None, *, tts=False, embed=None, embeds=None,
- file=None, files=None, delete_after=None,
- nonce=None, allowed_mentions=None, reference=None,
- mention_author=None, view=None):
+ async def send(
+ self,
+ content=None,
+ *,
+ tts=None,
+ embed=None,
+ embeds=None,
+ file=None,
+ files=None,
+ delete_after=None,
+ nonce=None,
+ allowed_mentions=None,
+ reference=None,
+ mention_author=None,
+ view=None,
+ ):
"""|coro|
Sends a message to the destination with the content given.
@@ -1185,7 +1246,7 @@ class Messageable(Protocol):
single :class:`~discord.File` object. To upload multiple files, the ``files``
parameter should be used with a :class:`list` of :class:`~discord.File` objects.
**Specifying both parameters will lead to an exception**.
-
+
To upload a single embed, the ``embed`` parameter should be used with a
single :class:`~discord.Embed` object. To upload multiple embeds, the ``embeds``
parameter should be used with a :class:`list` of :class:`~discord.Embed` objects.
@@ -1193,7 +1254,7 @@ class Messageable(Protocol):
Parameters
------------
- content: :class:`str`
+ content: Optional[:class:`str`]
The content of the message to send.
tts: :class:`bool`
Indicates if the message should be sent using text-to-speech.
@@ -1261,13 +1322,13 @@ class Messageable(Protocol):
channel = await self._get_channel()
state = self._state
content = str(content) if content is not None else None
-
+
if embed is not None and embeds is not None:
raise InvalidArgument('cannot pass both embed and embeds parameter to send()')
-
+
if embed is not None:
embed = embed.to_dict()
-
+
elif embeds is not None:
if len(embeds) > 10:
raise InvalidArgument('embeds parameter must be a list of up to 10 elements')
@@ -1307,9 +1368,18 @@ class Messageable(Protocol):
raise InvalidArgument('file parameter must be File')
try:
- data = await state.http.send_files(channel.id, files=[file], allowed_mentions=allowed_mentions,
- content=content, tts=tts, embed=embed, embeds=embeds,
- nonce=nonce, message_reference=reference, components=components)
+ data = await state.http.send_files(
+ channel.id,
+ files=[file],
+ allowed_mentions=allowed_mentions,
+ content=content,
+ tts=tts,
+ embed=embed,
+ embeds=embeds,
+ nonce=nonce,
+ message_reference=reference,
+ components=components,
+ )
finally:
file.close()
@@ -1320,17 +1390,33 @@ class Messageable(Protocol):
raise InvalidArgument('files parameter must be a list of File')
try:
- data = await state.http.send_files(channel.id, files=files, content=content, tts=tts,
- embed=embed, embeds=embeds, nonce=nonce,
- allowed_mentions=allowed_mentions, message_reference=reference,
- components=components)
+ data = await state.http.send_files(
+ channel.id,
+ files=files,
+ content=content,
+ tts=tts,
+ embed=embed,
+ embeds=embeds,
+ nonce=nonce,
+ allowed_mentions=allowed_mentions,
+ message_reference=reference,
+ components=components,
+ )
finally:
for f in files:
f.close()
else:
- data = await state.http.send_message(channel.id, content, tts=tts, embed=embed,
- embeds=embeds, nonce=nonce, allowed_mentions=allowed_mentions,
- message_reference=reference, components=components)
+ data = await state.http.send_message(
+ channel.id,
+ content,
+ tts=tts,
+ embed=embed,
+ embeds=embeds,
+ nonce=nonce,
+ allowed_mentions=allowed_mentions,
+ message_reference=reference,
+ components=components,
+ )
ret = state.create_message(channel=channel, data=data)
if view:
@@ -1340,7 +1426,7 @@ class Messageable(Protocol):
await ret.delete(delay=delete_after)
return ret
- async def trigger_typing(self):
+ async def trigger_typing(self) -> None:
"""|coro|
Triggers a *typing* indicator to the destination.
@@ -1351,7 +1437,7 @@ class Messageable(Protocol):
channel = await self._get_channel()
await self._state.http.send_typing(channel.id)
- def typing(self):
+ def typing(self) -> Typing:
"""Returns a context manager that allows you to type for an indefinite period of time.
This is useful for denoting long computations in your bot.
@@ -1362,8 +1448,8 @@ class Messageable(Protocol):
This means that both ``with`` and ``async with`` work with this.
Example Usage: ::
- async with channel.typing():
- # simulate something heavy
+ async with channel.typing():
+ # simulate something heavy
await asyncio.sleep(10)
await channel.send('done!')
@@ -1371,7 +1457,7 @@ class Messageable(Protocol):
"""
return Typing(self)
- async def fetch_message(self, id):
+ async def fetch_message(self, id: int, /) -> Message:
"""|coro|
Retrieves a single :class:`~discord.Message` from the destination.
@@ -1400,7 +1486,7 @@ class Messageable(Protocol):
data = await self._state.http.get_message(channel.id, id)
return self._state.create_message(channel=channel, data=data)
- async def pins(self):
+ async def pins(self) -> List[Message]:
"""|coro|
Retrieves all messages that are currently pinned in the channel.
@@ -1427,7 +1513,15 @@ class Messageable(Protocol):
data = await state.http.pins_from(channel.id)
return [state.create_message(channel=channel, data=m) for m in data]
- def history(self, *, limit=100, before=None, after=None, around=None, oldest_first=None):
+ def history(
+ self,
+ *,
+ limit: Optional[int] = 100,
+ before: Optional[SnowflakeTime] = None,
+ after: Optional[SnowflakeTime] = None,
+ around: Optional[SnowflakeTime] = None,
+ oldest_first: Optional[bool] = None,
+ ) -> HistoryIterator:
"""Returns an :class:`~discord.AsyncIterator` that enables receiving the destination's message history.
You must have :attr:`~discord.Permissions.read_message_history` permissions to use this.
@@ -1504,11 +1598,12 @@ class Connectable(Protocol):
"""
__slots__ = ()
+ _state: ConnectionState
- def _get_voice_client_key(self):
+ def _get_voice_client_key(self) -> Tuple[int, str]:
raise NotImplementedError
- def _get_voice_state_pair(self):
+ def _get_voice_state_pair(self) -> Tuple[int, int]:
raise NotImplementedError
async def connect(self, *, timeout: float = 60.0, reconnect: bool = True, cls: Type[T] = VoiceClient) -> T: