aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRapptz <[email protected]>2017-01-27 18:53:21 -0500
committerRapptz <[email protected]>2017-01-27 18:53:21 -0500
commit1c49374210b3f22eb044fd4b19ae9ae82bf4a45a (patch)
treec95f03232ce8a3cb9fc7d5cc1147838d19429b68
parent[commands] Add Context.command_failed attribute. (diff)
downloaddiscord.py-1c49374210b3f22eb044fd4b19ae9ae82bf4a45a.tar.xz
discord.py-1c49374210b3f22eb044fd4b19ae9ae82bf4a45a.zip
[commands] Implement before and after invoke command hooks.
Fixes #464.
-rw-r--r--discord/ext/commands/bot.py67
-rw-r--r--discord/ext/commands/core.py121
2 files changed, 186 insertions, 2 deletions
diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py
index 1913a319..d080d64a 100644
--- a/discord/ext/commands/bot.py
+++ b/discord/ext/commands/bot.py
@@ -136,6 +136,8 @@ class BotBase(GroupMixin):
self.cogs = {}
self.extensions = {}
self._checks = []
+ self._before_invoke = None
+ self._after_invoke = None
self.description = inspect.cleandoc(description) if description else ''
self.pm_help = pm_help
self.command_not_found = options.pop('command_not_found', 'No command called "{}" found.')
@@ -269,6 +271,71 @@ class BotBase(GroupMixin):
def can_run(self, ctx):
return all(f(ctx) for f in self._checks)
+ def before_invoke(self, coro):
+ """A decorator that registers a coroutine as a pre-invoke hook.
+
+ A pre-invoke hook is called directly before the command is
+ called. This makes it a useful function to set up database
+ connections or any type of set up required.
+
+ This pre-invoke hook takes a sole parameter, a :class:`Context`.
+
+ .. note::
+
+ The :meth:`before_invoke` and :meth:`after_invoke` hooks are
+ only called if all checks and argument parsing procedures pass
+ without error. If any check or argument parsing procedures fail
+ then the hooks are not called.
+
+ Parameters
+ -----------
+ coro
+ The coroutine to register as the pre-invoke hook.
+
+ Raises
+ -------
+ discord.ClientException
+ The coroutine is not actually a coroutine.
+ """
+ if not asyncio.iscoroutinefunction(coro):
+ raise discord.ClientException('The error handler must be a coroutine.')
+
+ self._before_invoke = coro
+ return coro
+
+ def after_invoke(self, coro):
+ """A decorator that registers a coroutine as a post-invoke hook.
+
+ A post-invoke hook is called directly after the command is
+ called. This makes it a useful function to clean-up database
+ connections or any type of clean up required.
+
+ This post-invoke hook takes a sole parameter, a :class:`Context`.
+
+ .. note::
+
+ Similar to :meth:`before_invoke`\, this is not called unless
+ checks and argument parsing procedures succeed. This hook is,
+ however, **always** called regardless of the internal command
+ callback raising an error (i.e. :exc:`CommandInvokeError`\).
+ This makes it ideal for clean-up scenarios.
+
+ Parameters
+ -----------
+ coro
+ The coroutine to register as the post-invoke hook.
+
+ Raises
+ -------
+ discord.ClientException
+ The coroutine is not actually a coroutine.
+ """
+ if not asyncio.iscoroutinefunction(coro):
+ raise discord.ClientException('The error handler must be a coroutine.')
+
+ self._after_invoke = coro
+ return coro
+
# listener registration
def add_listener(self, func, name=None):
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index 51a13f94..d1a91133 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -52,6 +52,21 @@ def wrap_callback(coro):
return ret
return wrapped
+def hooked_wrapped_callback(command, ctx, coro):
+ @functools.wraps(coro)
+ @asyncio.coroutine
+ def wrapped(*args, **kwargs):
+ try:
+ ret = yield from coro(*args, **kwargs)
+ except CommandError:
+ raise
+ except Exception as e:
+ raise CommandInvokeError(e) from e
+ finally:
+ yield from command.call_after_hooks(ctx)
+ return ret
+ return wrapped
+
def _convert_to_bool(argument):
lowered = argument.lower()
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
@@ -144,6 +159,8 @@ class Command:
self.instance = None
self.parent = None
self._buckets = CooldownMapping(kwargs.get('cooldown'))
+ self._before_invoke = None
+ self._after_invoke = None
@asyncio.coroutine
def dispatch_error(self, error, ctx):
@@ -336,6 +353,50 @@ class Command:
raise CheckFailure('The check functions for command {0.qualified_name} failed.'.format(self))
@asyncio.coroutine
+ def call_before_hooks(self, ctx):
+ # now that we're done preparing we can call the pre-command hooks
+ # first, call the command local hook:
+ cog = self.instance
+ if self._before_invoke is not None:
+ if cog is None:
+ yield from self._before_invoke(ctx)
+ else:
+ yield from self._before_invoke(cog, ctx)
+
+ # call the cog local hook if applicable:
+ try:
+ hook = getattr(cog, '_{0.__class__.__name__}__before_invoke'.format(cog))
+ except AttributeError:
+ pass
+ else:
+ yield from hook(ctx)
+
+ # call the bot global hook if necessary
+ hook = ctx.bot._before_invoke
+ if hook is not None:
+ yield from hook(ctx)
+
+ @asyncio.coroutine
+ def call_after_hooks(self, ctx):
+ cog = self.instance
+ if self._after_invoke is not None:
+ if cog is None:
+ yield from self._after_invoke(ctx)
+ else:
+ yield from self._after_invoke(cog, ctx)
+
+ try:
+ hook = getattr(cog, '_{0.__class__.__name__}__after_invoke'.format(cog))
+ except AttributeError:
+ pass
+ else:
+ yield from hook(ctx)
+
+ hook = ctx.bot._after_invoke
+ if hook is not None:
+ yield from hook(ctx)
+
+ @asyncio.coroutine
def prepare(self, ctx):
ctx.command = self
self._verify_checks(ctx)
@@ -347,6 +408,8 @@ class Command:
if retry_after:
raise CommandOnCooldown(bucket, retry_after)
+ yield from self.call_before_hooks(ctx)
+
def reset_cooldown(self, ctx):
"""Resets the cooldown on this command.
@@ -367,7 +430,7 @@ class Command:
# since we're in a regular command (and not a group) then
# the invoked subcommand is None.
ctx.invoked_subcommand = None
- injected = wrap_callback(self.callback)
+ injected = hooked_wrapped_callback(self, ctx, self.callback)
yield from injected(*ctx.args, **ctx.kwargs)
def error(self, coro):
@@ -394,6 +457,60 @@ class Command:
self.on_error = coro
return coro
+ def before_invoke(self, coro):
+ """A decorator that registers a coroutine as a pre-invoke hook.
+
+ A pre-invoke hook is called directly before :meth:`invoke` is
+ called. This makes it a useful function to set up database
+ connections or any type of set up required.
+
+ This pre-invoke hook takes a sole parameter, a :class:`Context`.
+
+ See :meth:`Bot.before_invoke` for more info.
+
+ Parameters
+ -----------
+ coro
+ The coroutine to register as the pre-invoke hook.
+
+ Raises
+ -------
+ discord.ClientException
+ The coroutine is not actually a coroutine.
+ """
+ if not asyncio.iscoroutinefunction(coro):
+ raise discord.ClientException('The error handler must be a coroutine.')
+
+ self._before_invoke = coro
+ return coro
+
+ def after_invoke(self, coro):
+ """A decorator that registers a coroutine as a post-invoke hook.
+
+ A post-invoke hook is called directly after :meth:`invoke` is
+ called. This makes it a useful function to clean-up database
+ connections or any type of clean up required.
+
+ This post-invoke hook takes a sole parameter, a :class:`Context`.
+
+ See :meth:`Bot.after_invoke` for more info.
+
+ Parameters
+ -----------
+ coro
+ The coroutine to register as the post-invoke hook.
+
+ Raises
+ -------
+ discord.ClientException
+ The coroutine is not actually a coroutine.
+ """
+ if not asyncio.iscoroutinefunction(coro):
+ raise discord.ClientException('The error handler must be a coroutine.')
+
+ self._after_invoke = coro
+ return coro
+
@property
def cog_name(self):
"""The name of the cog this command belongs to. None otherwise."""
@@ -610,7 +727,7 @@ class Group(GroupMixin, Command):
ctx.invoked_subcommand = self.commands.get(trigger, None)
if early_invoke:
- injected = wrap_callback(self.callback)
+ injected = hooked_wrapped_callback(self, ctx, self.callback)
yield from injected(*ctx.args, **ctx.kwargs)
if trigger and ctx.invoked_subcommand: