aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
authorJames <[email protected]>2021-04-04 23:05:49 +0100
committerGitHub <[email protected]>2021-04-04 18:05:49 -0400
commit34ab772653152d8e448f21cc4ccb9990edafae73 (patch)
treebb46bf3f14be9a672cd37358ad0d6948c6f49cfd /discord
parentFlatten AsyncIterator.flatten (diff)
downloaddiscord.py-34ab772653152d8e448f21cc4ccb9990edafae73.tar.xz
discord.py-34ab772653152d8e448f21cc4ccb9990edafae73.zip
Use typing.Protocol instead of abc.ABCMeta
Diffstat (limited to 'discord')
-rw-r--r--discord/abc.py112
-rw-r--r--discord/ext/commands/converter.py58
-rw-r--r--discord/ext/commands/core.py5
3 files changed, 87 insertions, 88 deletions
diff --git a/discord/abc.py b/discord/abc.py
index c3735f6c..4930ae31 100644
--- a/discord/abc.py
+++ b/discord/abc.py
@@ -22,10 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
-import abc
+from __future__ import annotations
+
import sys
import copy
import asyncio
+from typing import TYPE_CHECKING, Optional, Protocol, runtime_checkable
from .iterators import HistoryIterator
from .context_managers import Typing
@@ -39,13 +41,22 @@ from .file import File
from .voice_client import VoiceClient, VoiceProtocol
from . import utils
+if TYPE_CHECKING:
+ from datetime import datetime
+
+ from .user import ClientUser
+
+
class _Undefined:
def __repr__(self):
return 'see-below'
+
_undefined = _Undefined()
-class Snowflake(metaclass=abc.ABCMeta):
+
+@runtime_checkable
+class Snowflake(Protocol):
"""An ABC that details the common operations on a Discord model.
Almost all :ref:`Discord models <discord_api_models>` meet this
@@ -60,27 +71,16 @@ class Snowflake(metaclass=abc.ABCMeta):
The model's unique ID.
"""
__slots__ = ()
+ id: int
@property
- @abc.abstractmethod
- def created_at(self):
+ def created_at(self) -> datetime:
""":class:`datetime.datetime`: Returns the model's creation time as a naive datetime in UTC."""
raise NotImplementedError
- @classmethod
- def __subclasshook__(cls, C):
- if cls is Snowflake:
- mro = C.__mro__
- for attr in ('created_at', 'id'):
- for base in mro:
- if attr in base.__dict__:
- break
- else:
- return NotImplemented
- return True
- return NotImplemented
-class User(metaclass=abc.ABCMeta):
+@runtime_checkable
+class User(Snowflake, Protocol):
"""An ABC that details the common operations on a Discord user.
The following implement this ABC:
@@ -104,35 +104,24 @@ class User(metaclass=abc.ABCMeta):
"""
__slots__ = ()
+ name: str
+ discriminator: str
+ avatar: Optional[str]
+ bot: bool
+
@property
- @abc.abstractmethod
- def display_name(self):
+ def display_name(self) -> str:
""":class:`str`: Returns the user's display name."""
raise NotImplementedError
@property
- @abc.abstractmethod
- def mention(self):
+ def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the given user."""
raise NotImplementedError
- @classmethod
- def __subclasshook__(cls, C):
- if cls is User:
- if Snowflake.__subclasshook__(C) is NotImplemented:
- return NotImplemented
-
- mro = C.__mro__
- for attr in ('display_name', 'mention', 'name', 'avatar', 'discriminator', 'bot'):
- for base in mro:
- if attr in base.__dict__:
- break
- else:
- return NotImplemented
- return True
- return NotImplemented
-class PrivateChannel(metaclass=abc.ABCMeta):
+@runtime_checkable
+class PrivateChannel(Snowflake, Protocol):
"""An ABC that details the common operations on a private Discord channel.
The following implement this ABC:
@@ -149,18 +138,8 @@ class PrivateChannel(metaclass=abc.ABCMeta):
"""
__slots__ = ()
- @classmethod
- def __subclasshook__(cls, C):
- if cls is PrivateChannel:
- if Snowflake.__subclasshook__(C) is NotImplemented:
- return NotImplemented
+ me: ClientUser
- mro = C.__mro__
- for base in mro:
- if 'me' in base.__dict__:
- return True
- return NotImplemented
- return NotImplemented
class _Overwrites:
__slots__ = ('id', 'allow', 'deny', 'type')
@@ -179,7 +158,8 @@ class _Overwrites:
'type': self.type,
}
-class GuildChannel:
+
+class GuildChannel(Protocol):
"""An ABC that details the common operations on a Discord guild channel.
The following implement this ABC:
@@ -190,6 +170,11 @@ class GuildChannel:
This ABC must also implement :class:`~discord.abc.Snowflake`.
+ Note
+ ----
+ This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
+ checks.
+
Attributes
-----------
name: :class:`str`
@@ -826,14 +811,13 @@ class GuildChannel:
lock_permissions = kwargs.get('sync_permissions', False)
reason = kwargs.get('reason')
for index, channel in enumerate(channels):
- d = { 'id': channel.id, 'position': index }
+ d = {'id': channel.id, 'position': index}
if parent_id is not ... and channel.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d)
await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason)
-
async def create_invite(self, *, reason=None, **fields):
"""|coro|
@@ -908,7 +892,8 @@ class GuildChannel:
return result
-class Messageable(metaclass=abc.ABCMeta):
+
+class Messageable(Protocol):
"""An ABC that details the common operations on a model that can send messages.
The following implement this ABC:
@@ -919,11 +904,16 @@ class Messageable(metaclass=abc.ABCMeta):
- :class:`~discord.User`
- :class:`~discord.Member`
- :class:`~discord.ext.commands.Context`
+
+
+ Note
+ ----
+ This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
+ checks.
"""
__slots__ = ()
- @abc.abstractmethod
async def _get_channel(self):
raise NotImplementedError
@@ -1060,8 +1050,8 @@ class Messageable(metaclass=abc.ABCMeta):
f.close()
else:
data = await state.http.send_message(channel.id, content, tts=tts, embed=embed,
- nonce=nonce, allowed_mentions=allowed_mentions,
- message_reference=reference)
+ nonce=nonce, allowed_mentions=allowed_mentions,
+ message_reference=reference)
ret = state.create_message(channel=channel, data=data)
if delete_after is not None:
@@ -1213,21 +1203,25 @@ class Messageable(metaclass=abc.ABCMeta):
"""
return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)
-class Connectable(metaclass=abc.ABCMeta):
+
+class Connectable(Protocol):
"""An ABC that details the common operations on a channel that can
connect to a voice server.
The following implement this ABC:
- :class:`~discord.VoiceChannel`
+
+ Note
+ ----
+ This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
+ checks.
"""
__slots__ = ()
- @abc.abstractmethod
def _get_voice_client_key(self):
raise NotImplementedError
- @abc.abstractmethod
def _get_voice_state_pair(self):
raise NotImplementedError
@@ -1286,6 +1280,6 @@ class Connectable(metaclass=abc.ABCMeta):
except Exception:
# we don't care if disconnect failed because connection failed
pass
- raise # re-raise
+ raise # re-raise
return voice
diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py
index 010acca9..d5a30a83 100644
--- a/discord/ext/commands/converter.py
+++ b/discord/ext/commands/converter.py
@@ -22,14 +22,19 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
+from __future__ import annotations
+
import re
import inspect
-import typing
+from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, Union, runtime_checkable
import discord
-
from .errors import *
+if TYPE_CHECKING:
+ from .context import Context
+
+
__all__ = (
'Converter',
'MemberConverter',
@@ -54,6 +59,7 @@ __all__ = (
'Greedy',
)
+
def _get_from_guilds(bot, getter, argument):
result = None
for guild in bot.guilds:
@@ -62,9 +68,13 @@ def _get_from_guilds(bot, getter, argument):
return result
return result
+
_utils_get = discord.utils.get
+T = TypeVar("T")
-class Converter:
+
+@runtime_checkable
+class Converter(Protocol[T]):
"""The base class of custom converters that require the :class:`.Context`
to be passed to be useful.
@@ -75,7 +85,7 @@ class Converter:
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
"""
- async def convert(self, ctx, argument):
+ async def convert(self, ctx: Context, argument: str) -> T:
"""|coro|
The method to override to do conversion logic.
@@ -100,7 +110,7 @@ class Converter:
"""
raise NotImplementedError('Derived classes need to implement this.')
-class IDConverter(Converter):
+class IDConverter(Converter[T]):
def __init__(self):
self._id_regex = re.compile(r'([0-9]{15,20})$')
super().__init__()
@@ -108,7 +118,7 @@ class IDConverter(Converter):
def _get_id_match(self, argument):
return self._id_regex.match(argument)
-class MemberConverter(IDConverter):
+class MemberConverter(IDConverter[discord.Member]):
"""Converts to a :class:`~discord.Member`.
All lookups are via the local guild. If in a DM context, then the lookup
@@ -194,7 +204,7 @@ class MemberConverter(IDConverter):
return result
-class UserConverter(IDConverter):
+class UserConverter(IDConverter[discord.User]):
"""Converts to a :class:`~discord.User`.
All lookups are via the global user cache.
@@ -253,7 +263,7 @@ class UserConverter(IDConverter):
return result
-class PartialMessageConverter(Converter):
+class PartialMessageConverter(Converter[discord.PartialMessage], Generic[T]):
"""Converts to a :class:`discord.PartialMessage`.
.. versionadded:: 1.7
@@ -284,7 +294,7 @@ class PartialMessageConverter(Converter):
raise ChannelNotFound(channel_id)
return discord.PartialMessage(channel=channel, id=message_id)
-class MessageConverter(PartialMessageConverter):
+class MessageConverter(PartialMessageConverter[discord.Message]):
"""Converts to a :class:`discord.Message`.
.. versionadded:: 1.1
@@ -313,7 +323,7 @@ class MessageConverter(PartialMessageConverter):
except discord.Forbidden:
raise ChannelNotReadable(channel)
-class TextChannelConverter(IDConverter):
+class TextChannelConverter(IDConverter[discord.TextChannel]):
"""Converts to a :class:`~discord.TextChannel`.
All lookups are via the local guild. If in a DM context, then the lookup
@@ -355,7 +365,7 @@ class TextChannelConverter(IDConverter):
return result
-class VoiceChannelConverter(IDConverter):
+class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
"""Converts to a :class:`~discord.VoiceChannel`.
All lookups are via the local guild. If in a DM context, then the lookup
@@ -396,7 +406,7 @@ class VoiceChannelConverter(IDConverter):
return result
-class StageChannelConverter(IDConverter):
+class StageChannelConverter(IDConverter[discord.StageChannel]):
"""Converts to a :class:`~discord.StageChannel`.
.. versionadded:: 1.7
@@ -436,7 +446,7 @@ class StageChannelConverter(IDConverter):
return result
-class CategoryChannelConverter(IDConverter):
+class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
"""Converts to a :class:`~discord.CategoryChannel`.
All lookups are via the local guild. If in a DM context, then the lookup
@@ -478,7 +488,7 @@ class CategoryChannelConverter(IDConverter):
return result
-class StoreChannelConverter(IDConverter):
+class StoreChannelConverter(IDConverter[discord.StoreChannel]):
"""Converts to a :class:`~discord.StoreChannel`.
All lookups are via the local guild. If in a DM context, then the lookup
@@ -519,7 +529,7 @@ class StoreChannelConverter(IDConverter):
return result
-class ColourConverter(Converter):
+class ColourConverter(Converter[discord.Colour]):
"""Converts to a :class:`~discord.Colour`.
.. versionchanged:: 1.5
@@ -603,7 +613,7 @@ class ColourConverter(Converter):
ColorConverter = ColourConverter
-class RoleConverter(IDConverter):
+class RoleConverter(IDConverter[discord.Role]):
"""Converts to a :class:`~discord.Role`.
All lookups are via the local guild. If in a DM context, then the lookup
@@ -633,12 +643,12 @@ class RoleConverter(IDConverter):
raise RoleNotFound(argument)
return result
-class GameConverter(Converter):
+class GameConverter(Converter[discord.Game]):
"""Converts to :class:`~discord.Game`."""
async def convert(self, ctx, argument):
return discord.Game(name=argument)
-class InviteConverter(Converter):
+class InviteConverter(Converter[discord.Invite]):
"""Converts to a :class:`~discord.Invite`.
This is done via an HTTP request using :meth:`.Bot.fetch_invite`.
@@ -653,7 +663,7 @@ class InviteConverter(Converter):
except Exception as exc:
raise BadInviteArgument() from exc
-class GuildConverter(IDConverter):
+class GuildConverter(IDConverter[discord.Guild]):
"""Converts to a :class:`~discord.Guild`.
The lookup strategy is as follows (in order):
@@ -679,7 +689,7 @@ class GuildConverter(IDConverter):
raise GuildNotFound(argument)
return result
-class EmojiConverter(IDConverter):
+class EmojiConverter(IDConverter[discord.Emoji]):
"""Converts to a :class:`~discord.Emoji`.
All lookups are done for the local guild first, if available. If that lookup
@@ -722,7 +732,7 @@ class EmojiConverter(IDConverter):
return result
-class PartialEmojiConverter(Converter):
+class PartialEmojiConverter(Converter[discord.PartialEmoji]):
"""Converts to a :class:`~discord.PartialEmoji`.
This is done by extracting the animated flag, name and ID from the emoji.
@@ -743,7 +753,7 @@ class PartialEmojiConverter(Converter):
raise PartialEmojiConversionFailure(argument)
-class clean_content(Converter):
+class clean_content(Converter[str]):
"""Converts the argument to mention scrubbed version of
said content.
@@ -775,7 +785,7 @@ class clean_content(Converter):
if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id, *, _get=ctx.guild.get_channel):
ch = _get(id)
- return (f'<#{id}>'), ('#' + ch.name if ch else '#deleted-channel')
+ return f'<#{id}>', ('#' + ch.name if ch else '#deleted-channel')
transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions)
@@ -842,7 +852,7 @@ class _Greedy:
if converter is str or converter is type(None) or converter is _Greedy:
raise TypeError(f'Greedy[{converter.__name__}] is invalid.')
- if getattr(converter, '__origin__', None) is typing.Union and type(None) in converter.__args__:
+ if getattr(converter, '__origin__', None) is Union and type(None) in converter.__args__:
raise TypeError(f'Greedy[{converter!r}] is invalid.')
return self.__class__(converter=converter)
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index d3badff4..ec2e7deb 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -448,11 +448,6 @@ class Command(_BaseCommand):
instance = converter()
ret = await instance.convert(ctx, argument)
return ret
- else:
- method = getattr(converter, 'convert', None)
- if method is not None and inspect.ismethod(method):
- ret = await method(ctx, argument)
- return ret
elif isinstance(converter, converters.Converter):
ret = await converter.convert(ctx, argument)
return ret