aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2016-06-19 22:06:09 -0400
committerRapptz <[email protected]>2016-06-19 22:15:11 -0400
commitadbf2c720f192d20d6bd71bac55d1c5057a8baa1 (patch)
treeea26dfd1c340fda9a71ad4052d9678c7026d9207
parent[commands] Add `delete_after` keyword argument to utility functions. (diff)
downloaddiscord.py-adbf2c720f192d20d6bd71bac55d1c5057a8baa1.tar.xz
discord.py-adbf2c720f192d20d6bd71bac55d1c5057a8baa1.zip
[commands] Add the concept of global checks.
Global checks are checks that are executed before regular per-command checks except done to every command that the bot has registered. This allows you to have checks that apply to every command without having to override `on_message` or appending the check to every single command.
-rw-r--r--discord/ext/commands/bot.py87
-rw-r--r--discord/ext/commands/core.py5
2 files changed, 90 insertions, 2 deletions
diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py
index 4cbc8247..44a9f352 100644
--- a/discord/ext/commands/bot.py
+++ b/discord/ext/commands/bot.py
@@ -208,6 +208,7 @@ class Bot(GroupMixin, discord.Client):
self.extra_events = {}
self.cogs = {}
self.extensions = {}
+ self._checks = []
self.description = inspect.cleandoc(description) if description else ''
self.pm_help = pm_help
self.command_not_found = options.pop('command_not_found', 'No command called "{}" found.')
@@ -443,6 +444,70 @@ class Bot(GroupMixin, discord.Client):
destination = _get_variable('_internal_channel')
return self.send_typing(destination)
+ # global check registration
+
+ def check(self):
+ """A decorator that adds a global check to the bot.
+
+ A global check is similar to a :func:`check` that is applied
+ on a per command basis except it is run before any command checks
+ have been verified and applies to every command the bot has.
+
+ .. warning::
+
+ This function must be a *regular* function and not a coroutine.
+
+ Similar to a command :func:`check`\, this takes a single parameter
+ of type :class:`Context` and can only raise exceptions derived from
+ :exc:`CommandError`.
+
+ Example
+ ---------
+
+ .. code-block:: python
+
+ @bot.check
+ def whitelist(ctx):
+ return ctx.message.author.id in my_whitelist
+
+ """
+ def decorator(func):
+ self.add_check(func)
+ return func
+ return decorator
+
+ def add_check(self, func):
+ """Adds a global check to the bot.
+
+ This is the non-decorator interface to :meth:`check`.
+
+ Parameters
+ -----------
+ func
+ The function that was used as a global check.
+ """
+ self._checks.append(func)
+
+ def remove_check(self, func):
+ """Removes a global check from the bot.
+
+ This function is idempotent and will not raise an exception
+ if the function is not in the global checks.
+
+ Parameters
+ -----------
+ func
+ The function to remove from the global checks.
+ """
+
+ try:
+ self._checks.remove(func)
+ except ValueError:
+ pass
+
+ def can_run(self, ctx):
+ return all(f(ctx) for f in self._checks)
+
# listener registration
def add_listener(self, func, name=None):
@@ -543,6 +608,9 @@ class Bot(GroupMixin, discord.Client):
They are meant as a way to organize multiple relevant commands
into a singular class that shares some state or no state at all.
+ The cog can also have a ``__check`` member function that allows
+ you to define a global check. See :meth:`check` for more info.
+
More information will be documented soon.
Parameters
@@ -552,6 +620,14 @@ class Bot(GroupMixin, discord.Client):
"""
self.cogs[type(cog).__name__] = cog
+
+ try:
+ check = getattr(cog, '_{.__class__.__name__}__check'.format(cog))
+ except AttributeError:
+ pass
+ else:
+ self.add_check(check)
+
members = inspect.getmembers(cog)
for name, member in members:
# register commands the cog has
@@ -613,11 +689,20 @@ class Bot(GroupMixin, discord.Client):
if name.startswith('on_'):
self.remove_listener(member)
+ try:
+ check = getattr(cog, '_{0.__class__.__name__}__check'.format(cog))
+ except AttributeError:
+ pass
+ else:
+ self.remove_check(check)
+
unloader_name = '_{0.__class__.__name__}__unload'.format(cog)
try:
- getattr(cog, unloader_name)()
+ unloader = getattr(cog, unloader_name)
except AttributeError:
pass
+ else:
+ unloader()
del cog
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index ac91933d..330f2606 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -395,8 +395,11 @@ class Command:
if self.no_pm and ctx.message.channel.is_private:
raise NoPrivateMessage('This command cannot be used in private messages.')
+ if not ctx.bot.can_run(ctx):
+ raise CheckFailure('The global check functions for command {0.qualified_name} failed.'.format(self))
+
if not self.can_run(ctx):
- raise CheckFailure('The check functions for command {0.name} failed.'.format(self))
+ raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self))
@asyncio.coroutine
def invoke(self, ctx):