aboutsummaryrefslogtreecommitdiff
path: root/discord/ext
diff options
context:
space:
mode:
authorJosh <[email protected]>2021-08-20 09:51:26 +1000
committerGitHub <[email protected]>2021-08-19 19:51:26 -0400
commitf3cb19742914df4c7019f51393424fd88700a53e (patch)
treea6d714d1d78f0fbbf825bb15dda03de81fa06f0a /discord/ext
parentdefault to 0 instead of 15 for Guild.sticker_limit (diff)
downloaddiscord.py-f3cb19742914df4c7019f51393424fd88700a53e.tar.xz
discord.py-f3cb19742914df4c7019f51393424fd88700a53e.zip
[commands][types] Type hint commands-ext
Diffstat (limited to 'discord/ext')
-rw-r--r--discord/ext/commands/_types.py20
-rw-r--r--discord/ext/commands/bot.py132
-rw-r--r--discord/ext/commands/cog.py82
-rw-r--r--discord/ext/commands/context.py149
-rw-r--r--discord/ext/commands/core.py557
-rw-r--r--discord/ext/commands/help.py8
6 files changed, 636 insertions, 312 deletions
diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py
index 8c3c53a2..9b155987 100644
--- a/discord/ext/commands/_types.py
+++ b/discord/ext/commands/_types.py
@@ -22,6 +22,26 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
+
+from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
+
+
+if TYPE_CHECKING:
+ from .context import Context
+ from .cog import Cog
+ from .errors import CommandError
+
+T = TypeVar('T')
+
+Coro = Coroutine[Any, Any, T]
+MaybeCoro = Union[T, Coro[T]]
+CoroFunc = Callable[..., Coro[Any]]
+
+Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
+Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
+Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
+
+
# This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution.
class _BaseCommand:
diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py
index 7c49bf96..ba108153 100644
--- a/discord/ext/commands/bot.py
+++ b/discord/ext/commands/bot.py
@@ -22,13 +22,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
+from __future__ import annotations
+
+
import asyncio
import collections
+import collections.abc
import inspect
import importlib.util
import sys
import traceback
import types
+from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
import discord
@@ -39,6 +44,15 @@ from . import errors
from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog
+if TYPE_CHECKING:
+ import importlib.machinery
+
+ from discord.message import Message
+ from ._types import (
+ Check,
+ CoroFunc,
+ )
+
__all__ = (
'when_mentioned',
'when_mentioned_or',
@@ -46,14 +60,21 @@ __all__ = (
'AutoShardedBot',
)
-def when_mentioned(bot, msg):
+MISSING: Any = discord.utils.MISSING
+
+T = TypeVar('T')
+CFT = TypeVar('CFT', bound='CoroFunc')
+CXT = TypeVar('CXT', bound='Context')
+
+def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
"""
- return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> ']
+ # bot.user will never be None when this is called
+ return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
-def when_mentioned_or(*prefixes):
+def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@@ -89,7 +110,7 @@ def when_mentioned_or(*prefixes):
return inner
-def _is_submodule(parent, child):
+def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
class _DefaultRepr:
@@ -102,10 +123,10 @@ class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, **options):
super().__init__(**options)
self.command_prefix = command_prefix
- self.extra_events = {}
- self.__cogs = {}
- self.__extensions = {}
- self._checks = []
+ self.extra_events: Dict[str, List[CoroFunc]] = {}
+ self.__cogs: Dict[str, Cog] = {}
+ self.__extensions: Dict[str, types.ModuleType] = {}
+ self._checks: List[Check] = []
self._check_once = []
self._before_invoke = None
self._after_invoke = None
@@ -128,13 +149,14 @@ class BotBase(GroupMixin):
# internal helpers
- def dispatch(self, event_name, *args, **kwargs):
- super().dispatch(event_name, *args, **kwargs)
+ def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
+ # super() will resolve to Client
+ super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name
for event in self.extra_events.get(ev, []):
- self._schedule_event(event, ev, *args, **kwargs)
+ self._schedule_event(event, ev, *args, **kwargs) # type: ignore
- async def close(self):
+ async def close(self) -> None:
for extension in tuple(self.__extensions):
try:
self.unload_extension(extension)
@@ -147,9 +169,9 @@ class BotBase(GroupMixin):
except Exception:
pass
- await super().close()
+ await super().close() # type: ignore
- async def on_command_error(self, context, exception):
+ async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
"""|coro|
The default command error handler provided by the bot.
@@ -175,7 +197,7 @@ class BotBase(GroupMixin):
# global check registration
- def check(self, func):
+ def check(self, func: T) -> T:
r"""A decorator that adds a global check to the bot.
A global check is similar to a :func:`.check` that is applied
@@ -200,10 +222,11 @@ class BotBase(GroupMixin):
return ctx.command.qualified_name in allowed_commands
"""
- self.add_check(func)
+ # T was used instead of Check to ensure the type matches on return
+ self.add_check(func) # type: ignore
return func
- def add_check(self, func, *, call_once=False):
+ def add_check(self, func: Check, *, call_once: bool = False) -> None:
"""Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check`
@@ -223,7 +246,7 @@ class BotBase(GroupMixin):
else:
self._checks.append(func)
- def remove_check(self, func, *, call_once=False):
+ def remove_check(self, func: Check, *, call_once: bool = False) -> None:
"""Removes a global check from the bot.
This function is idempotent and will not raise an exception
@@ -244,7 +267,7 @@ class BotBase(GroupMixin):
except ValueError:
pass
- def check_once(self, func):
+ def check_once(self, func: CFT) -> CFT:
r"""A decorator that adds a "call once" global check to the bot.
Unlike regular global checks, this one is called only once
@@ -282,15 +305,16 @@ class BotBase(GroupMixin):
self.add_check(func, call_once=True)
return func
- async def can_run(self, ctx, *, call_once=False):
+ async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
data = self._check_once if call_once else self._checks
if len(data) == 0:
return True
- return await discord.utils.async_all(f(ctx) for f in data)
+ # type-checker doesn't distinguish between functions and methods
+ return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
- async def is_owner(self, user):
+ async def is_owner(self, user: discord.User) -> bool:
"""|coro|
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of
@@ -319,7 +343,8 @@ class BotBase(GroupMixin):
elif self.owner_ids:
return user.id in self.owner_ids
else:
- app = await self.application_info()
+
+ app = await self.application_info() # type: ignore
if app.team:
self.owner_ids = ids = {m.id for m in app.team.members}
return user.id in ids
@@ -327,7 +352,7 @@ class BotBase(GroupMixin):
self.owner_id = owner_id = app.owner.id
return user.id == owner_id
- def before_invoke(self, coro):
+ def before_invoke(self, coro: CFT) -> CFT:
"""A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is
@@ -359,7 +384,7 @@ class BotBase(GroupMixin):
self._before_invoke = coro
return coro
- def after_invoke(self, coro):
+ def after_invoke(self, coro: CFT) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is
@@ -394,14 +419,14 @@ class BotBase(GroupMixin):
# listener registration
- def add_listener(self, func, name=None):
+ def add_listener(self, func: CoroFunc, name: str = MISSING) -> None:
"""The non decorator alternative to :meth:`.listen`.
Parameters
-----------
func: :ref:`coroutine <coroutine>`
The function to call.
- name: Optional[:class:`str`]
+ name: :class:`str`
The name of the event to listen for. Defaults to ``func.__name__``.
Example
@@ -416,7 +441,7 @@ class BotBase(GroupMixin):
bot.add_listener(my_message, 'on_message')
"""
- name = func.__name__ if name is None else name
+ name = func.__name__ if name is MISSING else name
if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be coroutines')
@@ -426,7 +451,7 @@ class BotBase(GroupMixin):
else:
self.extra_events[name] = [func]
- def remove_listener(self, func, name=None):
+ def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None:
"""Removes a listener from the pool of listeners.
Parameters
@@ -438,7 +463,7 @@ class BotBase(GroupMixin):
``func.__name__``.
"""
- name = func.__name__ if name is None else name
+ name = func.__name__ if name is MISSING else name
if name in self.extra_events:
try:
@@ -446,7 +471,7 @@ class BotBase(GroupMixin):
except ValueError:
pass
- def listen(self, name=None):
+ def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]:
"""A decorator that registers another function as an external
event listener. Basically this allows you to listen to multiple
events from different places e.g. such as :func:`.on_ready`
@@ -476,7 +501,7 @@ class BotBase(GroupMixin):
The function being listened to is not a coroutine.
"""
- def decorator(func):
+ def decorator(func: CFT) -> CFT:
self.add_listener(func, name)
return func
@@ -528,7 +553,7 @@ class BotBase(GroupMixin):
cog = cog._inject(self)
self.__cogs[cog_name] = cog
- def get_cog(self, name):
+ def get_cog(self, name: str) -> Optional[Cog]:
"""Gets the cog instance requested.
If the cog is not found, ``None`` is returned instead.
@@ -547,7 +572,7 @@ class BotBase(GroupMixin):
"""
return self.__cogs.get(name)
- def remove_cog(self, name):
+ def remove_cog(self, name: str) -> Optional[Cog]:
"""Removes a cog from the bot and returns it.
All registered commands and event listeners that the
@@ -578,13 +603,13 @@ class BotBase(GroupMixin):
return cog
@property
- def cogs(self):
+ def cogs(self) -> Mapping[str, Cog]:
"""Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog."""
return types.MappingProxyType(self.__cogs)
# extensions
- def _remove_module_references(self, name):
+ def _remove_module_references(self, name: str) -> None:
# find all references to the module
# remove the cogs registered from the module
for cogname, cog in self.__cogs.copy().items():
@@ -608,7 +633,7 @@ class BotBase(GroupMixin):
for index in reversed(remove):
del event_list[index]
- def _call_module_finalizers(self, lib, key):
+ def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try:
func = getattr(lib, 'teardown')
except AttributeError:
@@ -626,12 +651,12 @@ class BotBase(GroupMixin):
if _is_submodule(name, module):
del sys.modules[module]
- def _load_from_module_spec(self, spec, key):
+ def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
# precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib
try:
- spec.loader.exec_module(lib)
+ spec.loader.exec_module(lib) # type: ignore
except Exception as e:
del sys.modules[key]
raise errors.ExtensionFailed(key, e) from e
@@ -652,13 +677,13 @@ class BotBase(GroupMixin):
else:
self.__extensions[key] = lib
- def _resolve_name(self, name, package):
+ def _resolve_name(self, name: str, package: Optional[str]) -> str:
try:
return importlib.util.resolve_name(name, package)
except ImportError:
raise errors.ExtensionNotFound(name)
- def load_extension(self, name, *, package=None):
+ def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Loads an extension.
An extension is a python module that contains commands, cogs, or
@@ -705,7 +730,7 @@ class BotBase(GroupMixin):
self._load_from_module_spec(spec, name)
- def unload_extension(self, name, *, package=None):
+ def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Unloads an extension.
When the extension is unloaded, all commands, listeners, and cogs are
@@ -746,7 +771,7 @@ class BotBase(GroupMixin):
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
- def reload_extension(self, name, *, package=None):
+ def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Atomically reloads an extension.
This replaces the extension with the same extension, only refreshed. This is
@@ -802,7 +827,7 @@ class BotBase(GroupMixin):
# if the load failed, the remnants should have been
# cleaned from the load_extension function call
# so let's load it from our old compiled library.
- lib.setup(self)
+ lib.setup(self) # type: ignore
self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller
@@ -810,18 +835,18 @@ class BotBase(GroupMixin):
raise
@property
- def extensions(self):
+ def extensions(self) -> Mapping[str, types.ModuleType]:
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension."""
return types.MappingProxyType(self.__extensions)
# help command stuff
@property
- def help_command(self):
+ def help_command(self) -> Optional[HelpCommand]:
return self._help_command
@help_command.setter
- def help_command(self, value):
+ def help_command(self, value: Optional[HelpCommand]) -> None:
if value is not None:
if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand')
@@ -837,7 +862,7 @@ class BotBase(GroupMixin):
# command processing
- async def get_prefix(self, message):
+ async def get_prefix(self, message: Message) -> Union[List[str], str]:
"""|coro|
Retrieves the prefix the bot is listening to
@@ -875,7 +900,7 @@ class BotBase(GroupMixin):
return ret
- async def get_context(self, message, *, cls=Context):
+ async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT:
r"""|coro|
Returns the invocation context from the message.
@@ -908,7 +933,7 @@ class BotBase(GroupMixin):
view = StringView(message.content)
ctx = cls(prefix=None, view=view, bot=self, message=message)
- if message.author.id == self.user.id:
+ if message.author.id == self.user.id: # type: ignore
return ctx
prefix = await self.get_prefix(message)
@@ -945,11 +970,12 @@ class BotBase(GroupMixin):
invoker = view.get_word()
ctx.invoked_with = invoker
- ctx.prefix = invoked_prefix
+ # type-checker fails to narrow invoked_prefix type.
+ ctx.prefix = invoked_prefix # type: ignore
ctx.command = self.all_commands.get(invoker)
return ctx
- async def invoke(self, ctx):
+ async def invoke(self, ctx: Context) -> None:
"""|coro|
Invokes the command given under the invocation context and
@@ -975,7 +1001,7 @@ class BotBase(GroupMixin):
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc)
- async def process_commands(self, message):
+ async def process_commands(self, message: Message) -> None:
"""|coro|
This function processes the commands that have been registered
diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py
index da428cff..9931557d 100644
--- a/discord/ext/commands/cog.py
+++ b/discord/ext/commands/cog.py
@@ -21,15 +21,30 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
+from __future__ import annotations
import inspect
+import discord.utils
+
+from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
+
from ._types import _BaseCommand
+if TYPE_CHECKING:
+ from .bot import BotBase
+ from .context import Context
+ from .core import Command
+
__all__ = (
'CogMeta',
'Cog',
)
+CogT = TypeVar('CogT', bound='Cog')
+FuncT = TypeVar('FuncT', bound=Callable[..., Any])
+
+MISSING: Any = discord.utils.MISSING
+
class CogMeta(type):
"""A metaclass for defining a cog.
@@ -89,8 +104,12 @@ class CogMeta(type):
async def bar(self, ctx):
pass # hidden -> False
"""
+ __cog_name__: str
+ __cog_settings__: Dict[str, Any]
+ __cog_commands__: List[Command]
+ __cog_listeners__: List[Tuple[str, str]]
- def __new__(cls, *args, **kwargs):
+ def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
@@ -143,14 +162,14 @@ class CogMeta(type):
new_cls.__cog_listeners__ = listeners_as_list
return new_cls
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args)
@classmethod
- def qualified_name(cls):
+ def qualified_name(cls) -> str:
return cls.__cog_name__
-def _cog_special_method(func):
+def _cog_special_method(func: FuncT) -> FuncT:
func.__cog_special_method__ = None
return func
@@ -164,8 +183,12 @@ class Cog(metaclass=CogMeta):
When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here.
"""
+ __cog_name__: ClassVar[str]
+ __cog_settings__: ClassVar[Dict[str, Any]]
+ __cog_commands__: ClassVar[List[Command]]
+ __cog_listeners__: ClassVar[List[Tuple[str, str]]]
- def __new__(cls, *args, **kwargs):
+ def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT:
# For issue 426, we need to store a copy of the command objects
# since we modify them to inject `self` to them.
# To do this, we need to interfere with the Cog creation process.
@@ -173,7 +196,8 @@ class Cog(metaclass=CogMeta):
cmd_attrs = cls.__cog_settings__
# Either update the command with the cog provided defaults or copy it.
- self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__)
+ # r.e type ignore, type-checker complains about overriding a ClassVar
+ self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
lookup = {
cmd.qualified_name: cmd
@@ -186,15 +210,15 @@ class Cog(metaclass=CogMeta):
parent = command.parent
if parent is not None:
# Get the latest parent reference
- parent = lookup[parent.qualified_name]
+ parent = lookup[parent.qualified_name] # type: ignore
# Update our parent's reference to our self
- parent.remove_command(command.name)
- parent.add_command(command)
+ parent.remove_command(command.name) # type: ignore
+ parent.add_command(command) # type: ignore
return self
- def get_commands(self):
+ def get_commands(self) -> List[Command]:
r"""
Returns
--------
@@ -209,20 +233,20 @@ class Cog(metaclass=CogMeta):
return [c for c in self.__cog_commands__ if c.parent is None]
@property
- def qualified_name(self):
+ def qualified_name(self) -> str:
""":class:`str`: Returns the cog's specified name, not the class name."""
return self.__cog_name__
@property
- def description(self):
+ def description(self) -> str:
""":class:`str`: Returns the cog's description, typically the cleaned docstring."""
return self.__cog_description__
@description.setter
- def description(self, description):
+ def description(self, description: str) -> None:
self.__cog_description__ = description
- def walk_commands(self):
+ def walk_commands(self) -> Generator[Command, None, None]:
"""An iterator that recursively walks through this cog's commands and subcommands.
Yields
@@ -237,7 +261,7 @@ class Cog(metaclass=CogMeta):
if isinstance(command, GroupMixin):
yield from command.walk_commands()
- def get_listeners(self):
+ def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]:
"""Returns a :class:`list` of (name, function) listener pairs that are defined in this cog.
Returns
@@ -248,12 +272,12 @@ class Cog(metaclass=CogMeta):
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
@classmethod
- def _get_overridden_method(cls, method):
+ def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, '__cog_special_method__', method)
@classmethod
- def listener(cls, name=None):
+ def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
"""A decorator that marks a function as a listener.
This is the cog equivalent of :meth:`.Bot.listen`.
@@ -271,10 +295,10 @@ class Cog(metaclass=CogMeta):
the name.
"""
- if name is not None and not isinstance(name, str):
+ if name is not MISSING and not isinstance(name, str):
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
- def decorator(func):
+ def decorator(func: FuncT) -> FuncT:
actual = func
if isinstance(actual, staticmethod):
actual = actual.__func__
@@ -293,7 +317,7 @@ class Cog(metaclass=CogMeta):
return func
return decorator
- def has_error_handler(self):
+ def has_error_handler(self) -> bool:
""":class:`bool`: Checks whether the cog has an error handler.
.. versionadded:: 1.7
@@ -301,7 +325,7 @@ class Cog(metaclass=CogMeta):
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
@_cog_special_method
- def cog_unload(self):
+ def cog_unload(self) -> None:
"""A special method that is called when the cog gets removed.
This function **cannot** be a coroutine. It must be a regular
@@ -312,7 +336,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
- def bot_check_once(self, ctx):
+ def bot_check_once(self, ctx: Context) -> bool:
"""A special method that registers as a :meth:`.Bot.check_once`
check.
@@ -322,7 +346,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
- def bot_check(self, ctx):
+ def bot_check(self, ctx: Context) -> bool:
"""A special method that registers as a :meth:`.Bot.check`
check.
@@ -332,7 +356,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
- def cog_check(self, ctx):
+ def cog_check(self, ctx: Context) -> bool:
"""A special method that registers as a :func:`~discord.ext.commands.check`
for every command and subcommand in this cog.
@@ -342,7 +366,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
- async def cog_command_error(self, ctx, error):
+ async def cog_command_error(self, ctx: Context, error: Exception) -> None:
"""A special method that is called whenever an error
is dispatched inside this cog.
@@ -361,7 +385,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
- async def cog_before_invoke(self, ctx):
+ async def cog_before_invoke(self, ctx: Context) -> None:
"""A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke`.
@@ -376,7 +400,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
- async def cog_after_invoke(self, ctx):
+ async def cog_after_invoke(self, ctx: Context) -> None:
"""A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke`.
@@ -390,7 +414,7 @@ class Cog(metaclass=CogMeta):
"""
pass
- def _inject(self, bot):
+ def _inject(self: CogT, bot: BotBase) -> CogT:
cls = self.__class__
# realistically, the only thing that can cause loading errors
@@ -425,7 +449,7 @@ class Cog(metaclass=CogMeta):
return self
- def _eject(self, bot):
+ def _eject(self, bot: BotBase) -> None:
cls = self.__class__
try:
diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py
index c5367c24..e231f0e7 100644
--- a/discord/ext/commands/context.py
+++ b/discord/ext/commands/context.py
@@ -21,16 +21,52 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
+from __future__ import annotations
+
+import inspect
+import re
+
+from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
import discord.abc
import discord.utils
-import re
+
+from discord.message import Message
+
+if TYPE_CHECKING:
+ from typing_extensions import ParamSpec
+
+ from discord.abc import MessageableChannel
+ from discord.guild import Guild
+ from discord.member import Member
+ from discord.state import ConnectionState
+ from discord.user import ClientUser, User
+ from discord.voice_client import VoiceProtocol
+
+ from .bot import Bot, AutoShardedBot
+ from .cog import Cog
+ from .core import Command
+ from .help import HelpCommand
+ from .view import StringView
__all__ = (
'Context',
)
-class Context(discord.abc.Messageable):
+MISSING: Any = discord.utils.MISSING
+
+
+T = TypeVar('T')
+BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
+CogT = TypeVar('CogT', bound="Cog")
+
+if TYPE_CHECKING:
+ P = ParamSpec('P')
+else:
+ P = TypeVar('P')
+
+
+class Context(discord.abc.Messageable, Generic[BotT]):
r"""Represents the context in which a command is being invoked under.
This class contains a lot of meta data to help you understand more about
@@ -58,11 +94,11 @@ class Context(discord.abc.Messageable):
This is only of use for within converters.
.. versionadded:: 2.0
- prefix: :class:`str`
+ prefix: Optional[:class:`str`]
The prefix that was used to invoke the command.
- command: :class:`Command`
+ command: Optional[:class:`Command`]
The command that is being invoked currently.
- invoked_with: :class:`str`
+ invoked_with: Optional[:class:`str`]
The command name that triggered this invocation. Useful for finding out
which alias called the command.
invoked_parents: List[:class:`str`]
@@ -73,7 +109,7 @@ class Context(discord.abc.Messageable):
.. versionadded:: 1.7
- invoked_subcommand: :class:`Command`
+ invoked_subcommand: Optional[:class:`Command`]
The subcommand that was invoked.
If no valid subcommand was invoked then this is equal to ``None``.
subcommand_passed: Optional[:class:`str`]
@@ -86,23 +122,38 @@ class Context(discord.abc.Messageable):
or invoked.
"""
- def __init__(self, **attrs):
- self.message = attrs.pop('message', None)
- self.bot = attrs.pop('bot', None)
- self.args = attrs.pop('args', [])
- self.kwargs = attrs.pop('kwargs', {})
- self.prefix = attrs.pop('prefix')
- self.command = attrs.pop('command', None)
- self.view = attrs.pop('view', None)
- self.invoked_with = attrs.pop('invoked_with', None)
- self.invoked_parents = attrs.pop('invoked_parents', [])
- self.invoked_subcommand = attrs.pop('invoked_subcommand', None)
- self.subcommand_passed = attrs.pop('subcommand_passed', None)
- self.command_failed = attrs.pop('command_failed', False)
- self.current_parameter = attrs.pop('current_parameter', None)
- self._state = self.message._state
-
- async def invoke(self, command, /, *args, **kwargs):
+ def __init__(self,
+ *,
+ message: Message,
+ bot: BotT,
+ view: StringView,
+ args: List[Any] = MISSING,
+ kwargs: Dict[str, Any] = MISSING,
+ prefix: Optional[str] = None,
+ command: Optional[Command] = None,
+ invoked_with: Optional[str] = None,
+ invoked_parents: List[str] = MISSING,
+ invoked_subcommand: Optional[Command] = None,
+ subcommand_passed: Optional[str] = None,
+ command_failed: bool = False,
+ current_parameter: Optional[inspect.Parameter] = None,
+ ):
+ self.message: Message = message
+ self.bot: BotT = bot
+ self.args: List[Any] = args or []
+ self.kwargs: Dict[str, Any] = kwargs or {}
+ self.prefix: Optional[str] = prefix
+ self.command: Optional[Command] = command
+ self.view: StringView = view
+ self.invoked_with: Optional[str] = invoked_with
+ self.invoked_parents: List[str] = invoked_parents or []
+ self.invoked_subcommand: Optional[Command] = invoked_subcommand
+ self.subcommand_passed: Optional[str] = subcommand_passed
+ self.command_failed: bool = command_failed
+ self.current_parameter: Optional[inspect.Parameter] = current_parameter
+ self._state: ConnectionState = self.message._state
+
+ async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
r"""|coro|
Calls a command with the arguments given.
@@ -133,17 +184,9 @@ class Context(discord.abc.Messageable):
TypeError
The command argument to invoke is missing.
"""
- arguments = []
- if command.cog is not None:
- arguments.append(command.cog)
-
- arguments.append(self)
- arguments.extend(args)
+ return await command(self, *args, **kwargs)
- ret = await command.callback(*arguments, **kwargs)
- return ret
-
- async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True):
+ async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None:
"""|coro|
Calls the command again.
@@ -187,7 +230,7 @@ class Context(discord.abc.Messageable):
if restart:
to_call = cmd.root_parent or cmd
- view.index = len(self.prefix)
+ view.index = len(self.prefix or '')
view.previous = 0
self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command
@@ -206,20 +249,23 @@ class Context(discord.abc.Messageable):
self.subcommand_passed = subcommand_passed
@property
- def valid(self):
+ def valid(self) -> bool:
""":class:`bool`: Checks if the invocation context is valid to be invoked with."""
return self.prefix is not None and self.command is not None
- async def _get_channel(self):
+ async def _get_channel(self) -> discord.abc.Messageable:
return self.channel
@property
- def clean_prefix(self):
+ def clean_prefix(self) -> str:
""":class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``.
.. versionadded:: 2.0
"""
- user = self.guild.me if self.guild else self.bot.user
+ if self.prefix is None:
+ return ''
+
+ user = self.me
# this breaks if the prefix mention is not the bot itself but I
# consider this to be an *incredibly* strange use case. I'd rather go
# for this common use case rather than waste performance for the
@@ -228,7 +274,7 @@ class Context(discord.abc.Messageable):
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
@property
- def cog(self):
+ def cog(self) -> Optional[Cog]:
"""Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist."""
if self.command is None:
@@ -236,38 +282,39 @@ class Context(discord.abc.Messageable):
return self.command.cog
@discord.utils.cached_property
- def guild(self):
+ def guild(self) -> Optional[Guild]:
"""Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available."""
return self.message.guild
@discord.utils.cached_property
- def channel(self):
+ def channel(self) -> MessageableChannel:
"""Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command.
Shorthand for :attr:`.Message.channel`.
"""
return self.message.channel
@discord.utils.cached_property
- def author(self):
+ def author(self) -> Union[User, Member]:
"""Union[:class:`~discord.User`, :class:`.Member`]:
Returns the author associated with this context's command. Shorthand for :attr:`.Message.author`
"""
return self.message.author
@discord.utils.cached_property
- def me(self):
+ def me(self) -> Union[Member, ClientUser]:
"""Union[:class:`.Member`, :class:`.ClientUser`]:
Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts.
"""
- return self.guild.me if self.guild is not None else self.bot.user
+ # bot.user will never be None at this point.
+ return self.guild.me if self.guild is not None else self.bot.user # type: ignore
@property
- def voice_client(self):
+ def voice_client(self) -> Optional[VoiceProtocol]:
r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
g = self.guild
return g.voice_client if g else None
- async def send_help(self, *args):
+ async def send_help(self, *args: Any) -> Any:
"""send_help(entity=<bot>)
|coro|
@@ -319,12 +366,12 @@ class Context(discord.abc.Messageable):
return None
entity = args[0]
- if entity is None:
- return None
-
if isinstance(entity, str):
entity = bot.get_cog(entity) or bot.get_command(entity)
+ if entity is None:
+ return None
+
try:
entity.qualified_name
except AttributeError:
@@ -348,6 +395,6 @@ class Context(discord.abc.Messageable):
except CommandError as e:
await cmd.on_help_command_error(self, e)
- @discord.utils.copy_doc(discord.Message.reply)
- async def reply(self, content=None, **kwargs):
+ @discord.utils.copy_doc(Message.reply)
+ async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message:
return await self.message.reply(content, **kwargs)
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index abf88326..88e65507 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -21,19 +21,29 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
+from __future__ import annotations
from typing import (
Any,
Callable,
Dict,
+ Generator,
+ Generic,
Literal,
+ List,
+ Optional,
Union,
+ Set,
+ Tuple,
+ TypeVar,
+ Type,
+ TYPE_CHECKING,
+ overload,
)
import asyncio
import functools
import inspect
import datetime
-import types
import discord
@@ -42,6 +52,22 @@ from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, Dy
from .converter import run_converters, get_converter, Greedy
from ._types import _BaseCommand
from .cog import Cog
+from .context import Context
+
+
+if TYPE_CHECKING:
+ from typing_extensions import Concatenate, ParamSpec, TypeGuard
+
+ from discord.message import Message
+
+ from ._types import (
+ Coro,
+ CoroFunc,
+ Check,
+ Hook,
+ Error,
+ )
+
__all__ = (
'Command',
@@ -70,6 +96,22 @@ __all__ = (
'bot_has_guild_permissions'
)
+MISSING: Any = discord.utils.MISSING
+
+T = TypeVar('T')
+CogT = TypeVar('CogT', bound='Cog')
+CommandT = TypeVar('CommandT', bound='Command')
+ContextT = TypeVar('ContextT', bound='Context')
+# CHT = TypeVar('CHT', bound='Check')
+GroupT = TypeVar('GroupT', bound='Group')
+HookT = TypeVar('HookT', bound='Hook')
+ErrorT = TypeVar('ErrorT', bound='Error')
+
+if TYPE_CHECKING:
+ P = ParamSpec('P')
+else:
+ P = TypeVar('P')
+
def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
partial = functools.partial
while True:
@@ -160,7 +202,7 @@ class _CaseInsensitiveDict(dict):
def __setitem__(self, k, v):
super().__setitem__(k.casefold(), v)
-class Command(_BaseCommand):
+class Command(_BaseCommand, Generic[CogT, P, T]):
r"""A class that implements the protocol for a bot text command.
These are not created manually, instead they are created via the
@@ -172,7 +214,7 @@ class Command(_BaseCommand):
The name of the command.
callback: :ref:`coroutine <coroutine>`
The coroutine that is executed when the command is called.
- help: :class:`str`
+ help: Optional[:class:`str`]
The long help text for the command.
brief: Optional[:class:`str`]
The short help text for the command.
@@ -235,8 +277,9 @@ class Command(_BaseCommand):
.. versionadded:: 2.0
"""
+ __original_kwargs__: Dict[str, Any]
- def __new__(cls, *args, **kwargs):
+ def __new__(cls: Type[CommandT], *args: Any, **kwargs: Any) -> CommandT:
# if you're wondering why this is done, it's because we need to ensure
# we have a complete original copy of **kwargs even for classes that
# mess with it by popping before delegating to the subclass __init__.
@@ -252,16 +295,20 @@ class Command(_BaseCommand):
self.__original_kwargs__ = kwargs.copy()
return self
- def __init__(self, func, **kwargs):
+ def __init__(self, func: Union[
+ Callable[Concatenate[CogT, ContextT, P], Coro[T]],
+ Callable[Concatenate[ContextT, P], Coro[T]],
+ ], **kwargs: Any):
if not asyncio.iscoroutinefunction(func):
raise TypeError('Callback must be a coroutine.')
- self.name = name = kwargs.get('name') or func.__name__
+ name = kwargs.get('name') or func.__name__
if not isinstance(name, str):
raise TypeError('Name of a command must be a string.')
+ self.name: str = name
self.callback = func
- self.enabled = kwargs.get('enabled', True)
+ self.enabled: bool = kwargs.get('enabled', True)
help_doc = kwargs.get('help')
if help_doc is not None:
@@ -271,74 +318,85 @@ class Command(_BaseCommand):
if isinstance(help_doc, bytes):
help_doc = help_doc.decode('utf-8')
- self.help = help_doc
+ self.help: Optional[str] = help_doc
- self.brief = kwargs.get('brief')
- self.usage = kwargs.get('usage')
- self.rest_is_raw = kwargs.get('rest_is_raw', False)
- self.aliases = kwargs.get('aliases', [])
- self.extras = kwargs.get('extras', {})
+ self.brief: Optional[str] = kwargs.get('brief')
+ self.usage: Optional[str] = kwargs.get('usage')
+ self.rest_is_raw: bool = kwargs.get('rest_is_raw', False)
+ self.aliases: Union[List[str], Tuple[str]] = kwargs.get('aliases', [])
+ self.extras: Dict[str, Any] = kwargs.get('extras', {})
if not isinstance(self.aliases, (list, tuple)):
raise TypeError("Aliases of a command must be a list or a tuple of strings.")
- self.description = inspect.cleandoc(kwargs.get('description', ''))
- self.hidden = kwargs.get('hidden', False)
+ self.description: str = inspect.cleandoc(kwargs.get('description', ''))
+ self.hidden: bool = kwargs.get('hidden', False)
try:
checks = func.__commands_checks__
checks.reverse()
except AttributeError:
checks = kwargs.get('checks', [])
- finally:
- self.checks = checks
+
+ self.checks: List[Check] = checks
try:
cooldown = func.__commands_cooldown__
except AttributeError:
cooldown = kwargs.get('cooldown')
- finally:
- if cooldown is None:
- self._buckets = CooldownMapping(cooldown, BucketType.default)
- elif isinstance(cooldown, CooldownMapping):
- self._buckets = cooldown
+
+ if cooldown is None:
+ buckets = CooldownMapping(cooldown, BucketType.default)
+ elif isinstance(cooldown, CooldownMapping):
+ buckets = cooldown
+ else:
+ raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
+ self._buckets: CooldownMapping = buckets
try:
max_concurrency = func.__commands_max_concurrency__
except AttributeError:
max_concurrency = kwargs.get('max_concurrency')
- finally:
- self._max_concurrency = max_concurrency
- self.require_var_positional = kwargs.get('require_var_positional', False)
- self.ignore_extra = kwargs.get('ignore_extra', True)
- self.cooldown_after_parsing = kwargs.get('cooldown_after_parsing', False)
- self.cog = None
+ self._max_concurrency: Optional[MaxConcurrency] = max_concurrency
+
+ self.require_var_positional: bool = kwargs.get('require_var_positional', False)
+ self.ignore_extra: bool = kwargs.get('ignore_extra', True)
+ self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False)
+ self.cog: Optional[CogT] = None
# bandaid for the fact that sometimes parent can be the bot instance
parent = kwargs.get('parent')
- self.parent = parent if isinstance(parent, _BaseCommand) else None
+ self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore
+ self._before_invoke: Optional[Hook] = None
try:
before_invoke = func.__before_invoke__
except AttributeError:
- self._before_invoke = None
+ pass
else:
self.before_invoke(before_invoke)
+ self._after_invoke: Optional[Hook] = None
try:
after_invoke = func.__after_invoke__
except AttributeError:
- self._after_invoke = None
+ pass
else:
self.after_invoke(after_invoke)
@property
- def callback(self):
+ def callback(self) -> Union[
+ Callable[Concatenate[CogT, Context, P], Coro[T]],
+ Callable[Concatenate[Context, P], Coro[T]],
+ ]:
return self._callback
@callback.setter
- def callback(self, function):
+ def callback(self, function: Union[
+ Callable[Concatenate[CogT, Context, P], Coro[T]],
+ Callable[Concatenate[Context, P], Coro[T]],
+ ]) -> None:
self._callback = function
unwrap = unwrap_function(function)
self.module = unwrap.__module__
@@ -350,7 +408,7 @@ class Command(_BaseCommand):
self.params = get_signature_parameters(function, globalns)
- def add_check(self, func):
+ def add_check(self, func: Check) -> None:
"""Adds a check to the command.
This is the non-decorator interface to :func:`.check`.
@@ -365,7 +423,7 @@ class Command(_BaseCommand):
self.checks.append(func)
- def remove_check(self, func):
+ def remove_check(self, func: Check) -> None:
"""Removes a check from the command.
This function is idempotent and will not raise an exception
@@ -384,8 +442,8 @@ class Command(_BaseCommand):
except ValueError:
pass
- def update(self, **kwargs):
- """Updates :class:`Command` instance with updated attributes.
+ def update(self, **kwargs: Any) -> None:
+ """Updates :class:`Command` instance with updated attribute.
This works similarly to the :func:`.command` decorator in terms
of parameters in that they are passed to the :class:`Command` or
@@ -393,7 +451,7 @@ class Command(_BaseCommand):
"""
self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs))
- async def __call__(self, *args, **kwargs):
+ async def __call__(self, context: Context, *args: P.args, **kwargs: P.kwargs) -> T:
"""|coro|
Calls the internal callback that the command holds.
@@ -407,11 +465,11 @@ class Command(_BaseCommand):
.. versionadded:: 1.3
"""
if self.cog is not None:
- return await self.callback(self.cog, *args, **kwargs)
+ return await self.callback(self.cog, context, *args, **kwargs) # type: ignore
else:
- return await self.callback(*args, **kwargs)
+ return await self.callback(context, *args, **kwargs) # type: ignore
- def _ensure_assignment_on_copy(self, other):
+ def _ensure_assignment_on_copy(self, other: CommandT) -> CommandT:
other._before_invoke = self._before_invoke
other._after_invoke = self._after_invoke
if self.checks != other.checks:
@@ -419,7 +477,8 @@ class Command(_BaseCommand):
if self._buckets.valid and not other._buckets.valid:
other._buckets = self._buckets.copy()
if self._max_concurrency != other._max_concurrency:
- other._max_concurrency = self._max_concurrency.copy()
+ # _max_concurrency won't be None at this point
+ other._max_concurrency = self._max_concurrency.copy() # type: ignore
try:
other.on_error = self.on_error
@@ -427,7 +486,7 @@ class Command(_BaseCommand):
pass
return other
- def copy(self):
+ def copy(self: CommandT) -> CommandT:
"""Creates a copy of this command.
Returns
@@ -438,7 +497,7 @@ class Command(_BaseCommand):
ret = self.__class__(self.callback, **self.__original_kwargs__)
return self._ensure_assignment_on_copy(ret)
- def _update_copy(self, kwargs):
+ def _update_copy(self: CommandT, kwargs: Dict[str, Any]) -> CommandT:
if kwargs:
kw = kwargs.copy()
kw.update(self.__original_kwargs__)
@@ -447,7 +506,7 @@ class Command(_BaseCommand):
else:
return self.copy()
- async def dispatch_error(self, ctx, error):
+ async def dispatch_error(self, ctx: Context, error: Exception) -> None:
ctx.command_failed = True
cog = self.cog
try:
@@ -470,7 +529,7 @@ class Command(_BaseCommand):
finally:
ctx.bot.dispatch('command_error', ctx, error)
- async def transform(self, ctx, param):
+ async def transform(self, ctx: Context, param: inspect.Parameter) -> Any:
required = param.default is param.empty
converter = get_converter(param)
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
@@ -508,9 +567,10 @@ class Command(_BaseCommand):
argument = view.get_quoted_word()
view.previous = previous
- return await run_converters(ctx, converter, argument, param)
+ # type-checker fails to narrow argument
+ return await run_converters(ctx, converter, argument, param) # type: ignore
- async def _transform_greedy_pos(self, ctx, param, required, converter):
+ async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any:
view = ctx.view
result = []
while not view.eof:
@@ -520,7 +580,7 @@ class Command(_BaseCommand):
view.skip_ws()
try:
argument = view.get_quoted_word()
- value = await run_converters(ctx, converter, argument, param)
+ value = await run_converters(ctx, converter, argument, param) # type: ignore
except (CommandError, ArgumentParsingError):
view.index = previous
break
@@ -531,12 +591,12 @@ class Command(_BaseCommand):
return param.default
return result
- async def _transform_greedy_var_pos(self, ctx, param, converter):
+ async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any:
view = ctx.view
previous = view.index
try:
argument = view.get_quoted_word()
- value = await run_converters(ctx, converter, argument, param)
+ value = await run_converters(ctx, converter, argument, param) # type: ignore
except (CommandError, ArgumentParsingError):
view.index = previous
raise RuntimeError() from None # break loop
@@ -567,7 +627,7 @@ class Command(_BaseCommand):
return result
@property
- def full_parent_name(self):
+ def full_parent_name(self) -> str:
""":class:`str`: Retrieves the fully qualified parent command name.
This the base command name required to execute it. For example,
@@ -575,14 +635,15 @@ class Command(_BaseCommand):
"""
entries = []
command = self
- while command.parent is not None:
- command = command.parent
- entries.append(command.name)
+ # command.parent is type-hinted as GroupMixin some attributes are resolved via MRO
+ while command.parent is not None: # type: ignore
+ command = command.parent # type: ignore
+ entries.append(command.name) # type: ignore
return ' '.join(reversed(entries))
@property
- def parents(self):
+ def parents(self) -> List[Group]:
"""List[:class:`Group`]: Retrieves the parents of this command.
If the command has no parents then it returns an empty :class:`list`.
@@ -593,14 +654,14 @@ class Command(_BaseCommand):
"""
entries = []
command = self
- while command.parent is not None:
- command = command.parent
+ while command.parent is not None: # type: ignore
+ command = command.parent # type: ignore
entries.append(command)
return entries
@property
- def root_parent(self):
+ def root_parent(self) -> Optional[Group]:
"""Optional[:class:`Group`]: Retrieves the root parent of this command.
If the command has no parents then it returns ``None``.
@@ -612,7 +673,7 @@ class Command(_BaseCommand):
return self.parents[-1]
@property
- def qualified_name(self):
+ def qualified_name(self) -> str:
""":class:`str`: Retrieves the fully qualified command name.
This is the full parent name with the command name as well.
@@ -626,10 +687,10 @@ class Command(_BaseCommand):
else:
return self.name
- def __str__(self):
+ def __str__(self) -> str:
return self.qualified_name
- async def _parse_arguments(self, ctx):
+ async def _parse_arguments(self, ctx: Context) -> None:
ctx.args = [ctx] if self.cog is None else [self.cog, ctx]
ctx.kwargs = {}
args = ctx.args
@@ -679,7 +740,7 @@ class Command(_BaseCommand):
if not self.ignore_extra and not view.eof:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
- async def call_before_hooks(self, ctx):
+ async def call_before_hooks(self, ctx: Context) -> None:
# now that we're done preparing we can call the pre-command hooks
# first, call the command local hook:
cog = self.cog
@@ -689,9 +750,9 @@ class Command(_BaseCommand):
# __self__ only exists for methods, not functions
# however, if @command.before_invoke is used, it will be a function
if instance:
- await self._before_invoke(instance, ctx)
+ await self._before_invoke(instance, ctx) # type: ignore
else:
- await self._before_invoke(ctx)
+ await self._before_invoke(ctx) # type: ignore
# call the cog local hook if applicable:
if cog is not None:
@@ -704,14 +765,14 @@ class Command(_BaseCommand):
if hook is not None:
await hook(ctx)
- async def call_after_hooks(self, ctx):
+ async def call_after_hooks(self, ctx: Context) -> None:
cog = self.cog
if self._after_invoke is not None:
instance = getattr(self._after_invoke, '__self__', cog)
if instance:
- await self._after_invoke(instance, ctx)
+ await self._after_invoke(instance, ctx) # type: ignore
else:
- await self._after_invoke(ctx)
+ await self._after_invoke(ctx) # type: ignore
# call the cog local hook if applicable:
if cog is not None:
@@ -723,7 +784,7 @@ class Command(_BaseCommand):
if hook is not None:
await hook(ctx)
- def _prepare_cooldowns(self, ctx):
+ def _prepare_cooldowns(self, ctx: Context) -> None:
if self._buckets.valid:
dt = ctx.message.edited_at or ctx.message.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
@@ -731,16 +792,17 @@ class Command(_BaseCommand):
if bucket is not None:
retry_after = bucket.update_rate_limit(current)
if retry_after:
- raise CommandOnCooldown(bucket, retry_after, self._buckets.type)
+ raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore
- async def prepare(self, ctx):
+ async def prepare(self, ctx: Context) -> None:
ctx.command = self
if not await self.can_run(ctx):
raise CheckFailure(f'The check functions for command {self.qualified_name} failed.')
if self._max_concurrency is not None:
- await self._max_concurrency.acquire(ctx)
+ # For this application, context can be duck-typed as a Message
+ await self._max_concurrency.acquire(ctx) # type: ignore
try:
if self.cooldown_after_parsing:
@@ -753,10 +815,10 @@ class Command(_BaseCommand):
await self.call_before_hooks(ctx)
except:
if self._max_concurrency is not None:
- await self._max_concurrency.release(ctx)
+ await self._max_concurrency.release(ctx) # type: ignore
raise
- def is_on_cooldown(self, ctx):
+ def is_on_cooldown(self, ctx: Context) -> bool:
"""Checks whether the command is currently on cooldown.
Parameters
@@ -777,7 +839,7 @@ class Command(_BaseCommand):
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
return bucket.get_tokens(current) == 0
- def reset_cooldown(self, ctx):
+ def reset_cooldown(self, ctx: Context) -> None:
"""Resets the cooldown on this command.
Parameters
@@ -789,7 +851,7 @@ class Command(_BaseCommand):
bucket = self._buckets.get_bucket(ctx.message)
bucket.reset()
- def get_cooldown_retry_after(self, ctx):
+ def get_cooldown_retry_after(self, ctx: Context) -> float:
"""Retrieves the amount of seconds before this command can be tried again.
.. versionadded:: 1.4
@@ -813,7 +875,7 @@ class Command(_BaseCommand):
return 0.0
- async def invoke(self, ctx):
+ async def invoke(self, ctx: Context) -> None:
await self.prepare(ctx)
# terminate the invoked_subcommand chain.
@@ -824,7 +886,7 @@ class Command(_BaseCommand):
injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs)
- async def reinvoke(self, ctx, *, call_hooks=False):
+ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:
ctx.command = self
await self._parse_arguments(ctx)
@@ -833,7 +895,7 @@ class Command(_BaseCommand):
ctx.invoked_subcommand = None
try:
- await self.callback(*ctx.args, **ctx.kwargs)
+ await self.callback(*ctx.args, **ctx.kwargs) # type: ignore
except:
ctx.command_failed = True
raise
@@ -841,7 +903,7 @@ class Command(_BaseCommand):
if call_hooks:
await self.call_after_hooks(ctx)
- def error(self, coro):
+ def error(self, coro: ErrorT) -> ErrorT:
"""A decorator that registers a coroutine as a local error handler.
A local error handler is an :func:`.on_command_error` event limited to
@@ -862,17 +924,17 @@ class Command(_BaseCommand):
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.')
- self.on_error = coro
+ self.on_error: Error = coro
return coro
- def has_error_handler(self):
+ def has_error_handler(self) -> bool:
""":class:`bool`: Checks whether the command has an error handler registered.
.. versionadded:: 1.7
"""
return hasattr(self, 'on_error')
- def before_invoke(self, coro):
+ def before_invoke(self, coro: HookT) -> HookT:
"""A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is
@@ -899,7 +961,7 @@ class Command(_BaseCommand):
self._before_invoke = coro
return coro
- def after_invoke(self, coro):
+ def after_invoke(self, coro: HookT) -> HookT:
"""A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is
@@ -927,12 +989,12 @@ class Command(_BaseCommand):
return coro
@property
- def cog_name(self):
+ def cog_name(self) -> Optional[str]:
"""Optional[:class:`str`]: The name of the cog this command belongs to, if any."""
return type(self.cog).__cog_name__ if self.cog is not None else None
@property
- def short_doc(self):
+ def short_doc(self) -> str:
""":class:`str`: Gets the "short" documentation of a command.
By default, this is the :attr:`.brief` attribute.
@@ -945,11 +1007,11 @@ class Command(_BaseCommand):
return self.help.split('\n', 1)[0]
return ''
- def _is_typing_optional(self, annotation):
- return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__
+ def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]:
+ return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ # type: ignore
@property
- def signature(self):
+ def signature(self) -> str:
""":class:`str`: Returns a POSIX-like signature useful for help command output."""
if self.usage is not None:
return self.usage
@@ -1002,7 +1064,7 @@ class Command(_BaseCommand):
return ' '.join(result)
- async def can_run(self, ctx):
+ async def can_run(self, ctx: Context) -> bool:
"""|coro|
Checks if the command can be executed by checking all the predicates
@@ -1052,7 +1114,7 @@ class Command(_BaseCommand):
# since we have no checks, then we just return True.
return True
- return await discord.utils.async_all(predicate(ctx) for predicate in predicates)
+ return await discord.utils.async_all(predicate(ctx) for predicate in predicates) # type: ignore
finally:
ctx.command = original
@@ -1068,24 +1130,24 @@ class GroupMixin:
case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``False``.
"""
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
case_insensitive = kwargs.get('case_insensitive', False)
- self.all_commands = _CaseInsensitiveDict() if case_insensitive else {}
- self.case_insensitive = case_insensitive
+ self.all_commands: Dict[str, Command] = _CaseInsensitiveDict() if case_insensitive else {}
+ self.case_insensitive: bool = case_insensitive
super().__init__(*args, **kwargs)
@property
- def commands(self):
+ def commands(self) -> Set[Command]:
"""Set[:class:`.Command`]: A unique set of commands without aliases that are registered."""
return set(self.all_commands.values())
- def recursively_remove_all_commands(self):
+ def recursively_remove_all_commands(self) -> None:
for command in self.all_commands.copy().values():
if isinstance(command, GroupMixin):
command.recursively_remove_all_commands()
self.remove_command(command.name)
- def add_command(self, command):
+ def add_command(self, command: Command) -> None:
"""Adds a :class:`.Command` into the internal list of commands.
This is usually not called, instead the :meth:`~.GroupMixin.command` or
@@ -1123,7 +1185,7 @@ class GroupMixin:
raise CommandRegistrationError(alias, alias_conflict=True)
self.all_commands[alias] = command
- def remove_command(self, name):
+ def remove_command(self, name: str) -> Optional[Command]:
"""Remove a :class:`.Command` from the internal list
of commands.
@@ -1156,11 +1218,11 @@ class GroupMixin:
# in the case of a CommandRegistrationError, an alias might conflict
# with an already existing command. If this is the case, we want to
# make sure the pre-existing command is not removed.
- if cmd not in (None, command):
+ if cmd is not None and cmd != command:
self.all_commands[alias] = cmd
return command
- def walk_commands(self):
+ def walk_commands(self) -> Generator[Command, None, None]:
"""An iterator that recursively walks through all commands and subcommands.
.. versionchanged:: 1.4
@@ -1176,7 +1238,7 @@ class GroupMixin:
if isinstance(command, GroupMixin):
yield from command.walk_commands()
- def get_command(self, name):
+ def get_command(self, name: str) -> Optional[Command]:
"""Get a :class:`.Command` from the internal list
of commands.
@@ -1210,13 +1272,39 @@ class GroupMixin:
for name in names[1:]:
try:
- obj = obj.all_commands[name]
+ obj = obj.all_commands[name] # type: ignore
except (AttributeError, KeyError):
return None
return obj
- def command(self, *args, **kwargs):
+ @overload
+ def command(
+ self,
+ name: str = ...,
+ cls: Type[Command[CogT, P, T]] = ...,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[T]]], Command[CogT, P, T]]:
+ ...
+
+ @overload
+ def command(
+ self,
+ name: str = ...,
+ cls: Type[CommandT] = ...,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]:
+ ...
+
+ def command(
+ self,
+ name: str = MISSING,
+ cls: Type[CommandT] = MISSING,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], CommandT]:
"""A shortcut decorator that invokes :func:`.command` and adds it to
the internal command list via :meth:`~.GroupMixin.add_command`.
@@ -1225,15 +1313,41 @@ class GroupMixin:
Callable[..., :class:`Command`]
A decorator that converts the provided method into a Command, adds it to the bot, then returns it.
"""
- def decorator(func):
+ def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT:
kwargs.setdefault('parent', self)
- result = command(*args, **kwargs)(func)
+ result = command(name=name, cls=cls, *args, **kwargs)(func)
self.add_command(result)
return result
return decorator
- def group(self, *args, **kwargs):
+ @overload
+ def group(
+ self,
+ name: str = ...,
+ cls: Type[Group[CogT, P, T]] = ...,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[T]]], Group[CogT, P, T]]:
+ ...
+
+ @overload
+ def group(
+ self,
+ name: str = ...,
+ cls: Type[GroupT] = ...,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]:
+ ...
+
+ def group(
+ self,
+ name: str = MISSING,
+ cls: Type[GroupT] = MISSING,
+ *args: Any,
+ **kwargs: Any,
+ ) -> Callable[[Callable[Concatenate[ContextT, P], Coro[Any]]], GroupT]:
"""A shortcut decorator that invokes :func:`.group` and adds it to
the internal command list via :meth:`~.GroupMixin.add_command`.
@@ -1242,15 +1356,15 @@ class GroupMixin:
Callable[..., :class:`Group`]
A decorator that converts the provided method into a Group, adds it to the bot, then returns it.
"""
- def decorator(func):
+ def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT:
kwargs.setdefault('parent', self)
- result = group(*args, **kwargs)(func)
+ result = group(name=name, cls=cls, *args, **kwargs)(func)
self.add_command(result)
return result
return decorator
-class Group(GroupMixin, Command):
+class Group(GroupMixin, Command[CogT, P, T]):
"""A class that implements a grouping protocol for commands to be
executed as subcommands.
@@ -1272,11 +1386,11 @@ class Group(GroupMixin, Command):
Indicates if the group's commands should be case insensitive.
Defaults to ``False``.
"""
- def __init__(self, *args, **attrs):
- self.invoke_without_command = attrs.pop('invoke_without_command', False)
+ def __init__(self, *args: Any, **attrs: Any) -> None:
+ self.invoke_without_command: bool = attrs.pop('invoke_without_command', False)
super().__init__(*args, **attrs)
- def copy(self):
+ def copy(self: GroupT) -> GroupT:
"""Creates a copy of this :class:`Group`.
Returns
@@ -1287,9 +1401,9 @@ class Group(GroupMixin, Command):
ret = super().copy()
for cmd in self.commands:
ret.add_command(cmd.copy())
- return ret
+ return ret # type: ignore
- async def invoke(self, ctx):
+ async def invoke(self, ctx: Context) -> None:
ctx.invoked_subcommand = None
ctx.subcommand_passed = None
early_invoke = not self.invoke_without_command
@@ -1309,7 +1423,7 @@ class Group(GroupMixin, Command):
injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs)
- ctx.invoked_parents.append(ctx.invoked_with)
+ ctx.invoked_parents.append(ctx.invoked_with) # type: ignore
if trigger and ctx.invoked_subcommand:
ctx.invoked_with = trigger
@@ -1320,7 +1434,7 @@ class Group(GroupMixin, Command):
view.previous = previous
await super().invoke(ctx)
- async def reinvoke(self, ctx, *, call_hooks=False):
+ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:
ctx.invoked_subcommand = None
early_invoke = not self.invoke_without_command
if early_invoke:
@@ -1341,7 +1455,7 @@ class Group(GroupMixin, Command):
if early_invoke:
try:
- await self.callback(*ctx.args, **ctx.kwargs)
+ await self.callback(*ctx.args, **ctx.kwargs) # type: ignore
except:
ctx.command_failed = True
raise
@@ -1349,7 +1463,7 @@ class Group(GroupMixin, Command):
if call_hooks:
await self.call_after_hooks(ctx)
- ctx.invoked_parents.append(ctx.invoked_with)
+ ctx.invoked_parents.append(ctx.invoked_with) # type: ignore
if trigger and ctx.invoked_subcommand:
ctx.invoked_with = trigger
@@ -1362,7 +1476,48 @@ class Group(GroupMixin, Command):
# Decorators
-def command(name=None, cls=None, **attrs):
+@overload
+def command(
+ name: str = ...,
+ cls: Type[Command[CogT, P, T]] = ...,
+ **attrs: Any,
+) -> Callable[
+ [
+ Union[
+ Callable[Concatenate[CogT, ContextT, P], Coro[T]],
+ Callable[Concatenate[ContextT, P], Coro[T]],
+ ]
+ ]
+, Command[CogT, P, T]]:
+ ...
+
+@overload
+def command(
+ name: str = ...,
+ cls: Type[CommandT] = ...,
+ **attrs: Any,
+) -> Callable[
+ [
+ Union[
+ Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
+ Callable[Concatenate[ContextT, P], Coro[Any]],
+ ]
+ ]
+, CommandT]:
+ ...
+
+def command(
+ name: str = MISSING,
+ cls: Type[CommandT] = MISSING,
+ **attrs: Any
+) -> Callable[
+ [
+ Union[
+ Callable[Concatenate[ContextT, P], Coro[Any]],
+ Callable[Concatenate[CogT, ContextT, P], Coro[T]],
+ ]
+ ]
+, Union[Command[CogT, P, T], CommandT]]:
"""A decorator that transforms a function into a :class:`.Command`
or if called with :func:`.group`, :class:`.Group`.
@@ -1392,17 +1547,61 @@ def command(name=None, cls=None, **attrs):
TypeError
If the function is not a coroutine or is already a command.
"""
- if cls is None:
- cls = Command
+ if cls is MISSING:
+ cls = Command # type: ignore
- def decorator(func):
+ def decorator(func: Union[
+ Callable[Concatenate[ContextT, P], Coro[Any]],
+ Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
+ ]) -> CommandT:
if isinstance(func, Command):
raise TypeError('Callback is already a command.')
return cls(func, name=name, **attrs)
return decorator
-def group(name=None, **attrs):
+@overload
+def group(
+ name: str = ...,
+ cls: Type[Group[CogT, P, T]] = ...,
+ **attrs: Any,
+) -> Callable[
+ [
+ Union[
+ Callable[Concatenate[CogT, ContextT, P], Coro[T]],
+ Callable[Concatenate[ContextT, P], Coro[T]],
+ ]
+ ]
+, Group[CogT, P, T]]:
+ ...
+
+@overload
+def group(
+ name: str = ...,
+ cls: Type[GroupT] = ...,
+ **attrs: Any,
+) -> Callable[
+ [
+ Union[
+ Callable[Concatenate[CogT, ContextT, P], Coro[Any]],
+ Callable[Concatenate[ContextT, P], Coro[Any]],
+ ]
+ ]
+, GroupT]:
+ ...
+
+def group(
+ name: str = MISSING,
+ cls: Type[GroupT] = MISSING,
+ **attrs: Any,
+) -> Callable[
+ [
+ Union[
+ Callable[Concatenate[ContextT, P], Coro[Any]],
+ Callable[Concatenate[CogT, ContextT, P], Coro[T]],
+ ]
+ ]
+, Union[Group[CogT, P, T], GroupT]]:
"""A decorator that transforms a function into a :class:`.Group`.
This is similar to the :func:`.command` decorator but the ``cls``
@@ -1411,11 +1610,11 @@ def group(name=None, **attrs):
.. versionchanged:: 1.1
The ``cls`` parameter can now be passed.
"""
+ if cls is MISSING:
+ cls = Group # type: ignore
+ return command(name=name, cls=cls, **attrs) # type: ignore
- attrs.setdefault('cls', Group)
- return command(name=name, **attrs)
-
-def check(predicate):
+def check(predicate: Check) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`.
@@ -1486,7 +1685,7 @@ def check(predicate):
The predicate to check if the command should be invoked.
"""
- def decorator(func):
+ def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func.checks.append(predicate)
else:
@@ -1502,12 +1701,12 @@ def check(predicate):
else:
@functools.wraps(predicate)
async def wrapper(ctx):
- return predicate(ctx)
+ return predicate(ctx) # type: ignore
decorator.predicate = wrapper
- return decorator
+ return decorator # type: ignore
-def check_any(*checks):
+def check_any(*checks: Check) -> Callable[[T], T]:
r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR.
@@ -1560,7 +1759,7 @@ def check_any(*checks):
else:
unwrapped.append(pred)
- async def predicate(ctx):
+ async def predicate(ctx: Context) -> bool:
errors = []
for func in unwrapped:
try:
@@ -1575,7 +1774,7 @@ def check_any(*checks):
return check(predicate)
-def has_role(item):
+def has_role(item: Union[int, str]) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member invoking the
command has the role specified via the name or ID specified.
@@ -1602,21 +1801,22 @@ def has_role(item):
The name or ID of the role to check.
"""
- def predicate(ctx):
+ def predicate(ctx: Context) -> bool:
if ctx.guild is None:
raise NoPrivateMessage()
+ # ctx.guild is None doesn't narrow ctx.author to Member
if isinstance(item, int):
- role = discord.utils.get(ctx.author.roles, id=item)
+ role = discord.utils.get(ctx.author.roles, id=item) # type: ignore
else:
- role = discord.utils.get(ctx.author.roles, name=item)
+ role = discord.utils.get(ctx.author.roles, name=item) # type: ignore
if role is None:
raise MissingRole(item)
return True
return check(predicate)
-def has_any_role(*items):
+def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
r"""A :func:`.check` that is added that checks if the member invoking the
command has **any** of the roles specified. This means that if they have
one out of the three roles specified, then this check will return `True`.
@@ -1651,14 +1851,15 @@ def has_any_role(*items):
if ctx.guild is None:
raise NoPrivateMessage()
- getter = functools.partial(discord.utils.get, ctx.author.roles)
+ # ctx.guild is None doesn't narrow ctx.author to Member
+ getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
return True
- raise MissingAnyRole(items)
+ raise MissingAnyRole(list(items))
return check(predicate)
-def bot_has_role(item):
+def bot_has_role(item: int) -> Callable[[T], T]:
"""Similar to :func:`.has_role` except checks if the bot itself has the
role.
@@ -1686,7 +1887,7 @@ def bot_has_role(item):
return True
return check(predicate)
-def bot_has_any_role(*items):
+def bot_has_any_role(*items: int) -> Callable[[T], T]:
"""Similar to :func:`.has_any_role` except checks if the bot itself has
any of the roles listed.
@@ -1707,10 +1908,10 @@ def bot_has_any_role(*items):
getter = functools.partial(discord.utils.get, me.roles)
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
return True
- raise BotMissingAnyRole(items)
+ raise BotMissingAnyRole(list(items))
return check(predicate)
-def has_permissions(**perms):
+def has_permissions(**perms: bool) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member has all of
the permissions necessary.
@@ -1744,9 +1945,9 @@ def has_permissions(**perms):
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
- def predicate(ctx):
+ def predicate(ctx: Context) -> bool:
ch = ctx.channel
- permissions = ch.permissions_for(ctx.author)
+ permissions = ch.permissions_for(ctx.author) # type: ignore
missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]
@@ -1757,7 +1958,7 @@ def has_permissions(**perms):
return check(predicate)
-def bot_has_permissions(**perms):
+def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_permissions` except checks if the bot itself has
the permissions listed.
@@ -1769,10 +1970,10 @@ def bot_has_permissions(**perms):
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
- def predicate(ctx):
+ def predicate(ctx: Context) -> bool:
guild = ctx.guild
me = guild.me if guild is not None else ctx.bot.user
- permissions = ctx.channel.permissions_for(me)
+ permissions = ctx.channel.permissions_for(me) # type: ignore
missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]
@@ -1783,7 +1984,7 @@ def bot_has_permissions(**perms):
return check(predicate)
-def has_guild_permissions(**perms):
+def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_permissions`, but operates on guild wide
permissions instead of the current channel permissions.
@@ -1797,11 +1998,11 @@ def has_guild_permissions(**perms):
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
- def predicate(ctx):
+ def predicate(ctx: Context) -> bool:
if not ctx.guild:
raise NoPrivateMessage
- permissions = ctx.author.guild_permissions
+ permissions = ctx.author.guild_permissions # type: ignore
missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]
if not missing:
@@ -1811,7 +2012,7 @@ def has_guild_permissions(**perms):
return check(predicate)
-def bot_has_guild_permissions(**perms):
+def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
"""Similar to :func:`.has_guild_permissions`, but checks the bot
members guild permissions.
@@ -1822,11 +2023,11 @@ def bot_has_guild_permissions(**perms):
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
- def predicate(ctx):
+ def predicate(ctx: Context) -> bool:
if not ctx.guild:
raise NoPrivateMessage
- permissions = ctx.me.guild_permissions
+ permissions = ctx.me.guild_permissions # type: ignore
missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]
if not missing:
@@ -1836,7 +2037,7 @@ def bot_has_guild_permissions(**perms):
return check(predicate)
-def dm_only():
+def dm_only() -> Callable[[T], T]:
"""A :func:`.check` that indicates this command must only be used in a
DM context. Only private messages are allowed when
using the command.
@@ -1847,14 +2048,14 @@ def dm_only():
.. versionadded:: 1.1
"""
- def predicate(ctx):
+ def predicate(ctx: Context) -> bool:
if ctx.guild is not None:
raise PrivateMessageOnly()
return True
return check(predicate)
-def guild_only():
+def guild_only() -> Callable[[T], T]:
"""A :func:`.check` that indicates this command must only be used in a
guild context only. Basically, no private messages are allowed when
using the command.
@@ -1863,14 +2064,14 @@ def guild_only():
that is inherited from :exc:`.CheckFailure`.
"""
- def predicate(ctx):
+ def predicate(ctx: Context) -> bool:
if ctx.guild is None:
raise NoPrivateMessage()
return True
return check(predicate)
-def is_owner():
+def is_owner() -> Callable[[T], T]:
"""A :func:`.check` that checks if the person invoking this command is the
owner of the bot.
@@ -1880,14 +2081,14 @@ def is_owner():
from :exc:`.CheckFailure`.
"""
- async def predicate(ctx):
+ async def predicate(ctx: Context) -> bool:
if not await ctx.bot.is_owner(ctx.author):
raise NotOwner('You do not own this bot.')
return True
return check(predicate)
-def is_nsfw():
+def is_nsfw() -> Callable[[T], T]:
"""A :func:`.check` that checks if the channel is a NSFW channel.
This check raises a special exception, :exc:`.NSFWChannelRequired`
@@ -1898,14 +2099,14 @@ def is_nsfw():
Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`.
DM channels will also now pass this check.
"""
- def pred(ctx):
+ def pred(ctx: Context) -> bool:
ch = ctx.channel
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
return True
- raise NSFWChannelRequired(ch)
+ raise NSFWChannelRequired(ch) # type: ignore
return check(pred)
-def cooldown(rate, per, type=BucketType.default):
+def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default) -> Callable[[T], T]:
"""A decorator that adds a cooldown to a :class:`.Command`
A cooldown allows a command to only be used a specific amount
@@ -1932,15 +2133,15 @@ def cooldown(rate, per, type=BucketType.default):
Callables are now supported for custom bucket types.
"""
- def decorator(func):
+ def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func._buckets = CooldownMapping(Cooldown(rate, per), type)
else:
func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type)
return func
- return decorator
+ return decorator # type: ignore
-def dynamic_cooldown(cooldown, type=BucketType.default):
+def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default) -> Callable[[T], T]:
"""A decorator that adds a dynamic cooldown to a :class:`.Command`
This differs from :func:`.cooldown` in that it takes a function that
@@ -1972,15 +2173,15 @@ def dynamic_cooldown(cooldown, type=BucketType.default):
if not callable(cooldown):
raise TypeError("A callable must be provided")
- def decorator(func):
+ def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func._buckets = DynamicCooldownMapping(cooldown, type)
else:
func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type)
return func
- return decorator
+ return decorator # type: ignore
-def max_concurrency(number, per=BucketType.default, *, wait=False):
+def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]:
"""A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses.
This enables you to only allow a certain number of command invocations at the same time,
@@ -2004,16 +2205,16 @@ def max_concurrency(number, per=BucketType.default, *, wait=False):
then the command waits until it can be executed.
"""
- def decorator(func):
+ def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
value = MaxConcurrency(number, per=per, wait=wait)
if isinstance(func, Command):
func._max_concurrency = value
else:
func.__commands_max_concurrency__ = value
return func
- return decorator
+ return decorator # type: ignore
-def before_invoke(coro):
+def before_invoke(coro) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a pre-invoke hook.
This allows you to refer to one before invoke hook for several commands that
@@ -2051,15 +2252,15 @@ def before_invoke(coro):
bot.add_cog(What())
"""
- def decorator(func):
+ def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func.before_invoke(coro)
else:
func.__before_invoke__ = coro
return func
- return decorator
+ return decorator # type: ignore
-def after_invoke(coro):
+def after_invoke(coro) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a post-invoke hook.
This allows you to refer to one after invoke hook for several commands that
@@ -2067,10 +2268,10 @@ def after_invoke(coro):
.. versionadded:: 1.4
"""
- def decorator(func):
+ def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
if isinstance(func, Command):
func.after_invoke(coro)
else:
func.__after_invoke__ = coro
return func
- return decorator
+ return decorator # type: ignore
diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py
index 6de81bb5..6a70726d 100644
--- a/discord/ext/commands/help.py
+++ b/discord/ext/commands/help.py
@@ -27,11 +27,17 @@ import copy
import functools
import inspect
import re
+
+from typing import Optional, TYPE_CHECKING
+
import discord.utils
from .core import Group, Command
from .errors import CommandError
+if TYPE_CHECKING:
+ from .context import Context
+
__all__ = (
'Paginator',
'HelpCommand',
@@ -320,7 +326,7 @@ class HelpCommand:
self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message')
- self.context = None
+ self.context: Optional[Context] = None
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
def copy(self):