aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--discord/ext/commands/context.py53
-rw-r--r--discord/ext/commands/core.py77
2 files changed, 128 insertions, 2 deletions
diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py
index ef7d9ca6..53aecdb1 100644
--- a/discord/ext/commands/context.py
+++ b/discord/ext/commands/context.py
@@ -121,6 +121,59 @@ class Context(discord.abc.Messageable):
ret = yield from command.callback(*arguments, **kwargs)
return ret
+ @asyncio.coroutine
+ def reinvoke(self, *, call_hooks=False, restart=True):
+ """|coro|
+
+ Calls the command again.
+
+ This is similar to :meth:`~.Context.invoke` except that it bypasses
+ checks, cooldowns, and error handlers.
+
+ .. note::
+
+ If you want to bypass :exc:`.UserInputError` derived exceptions,
+ it is recommended to use the regular :meth:`~.Context.invoke`
+ as it will work more naturally. After all, this will end up
+ using the old arguments the user has used and will thus just
+ fail again.
+
+ Parameters
+ ------------
+ call_hooks: bool
+ Whether to call the before and after invoke hooks.
+ restart: bool
+ Whether to start the call chain from the very beginning
+ or where we left off (i.e. the command that caused the error).
+ """
+ cmd = self.command
+ view = self.view
+ if cmd is None:
+ raise ValueError('This context is not valid.')
+
+ # some state to revert to when we're done
+ index, previous = view.index, view.previous
+ invoked_with = self.invoked_with
+ invoked_subcommand = self.invoked_subcommand
+ subcommand_passed = self.subcommand_passed
+
+ if restart:
+ to_call = cmd.root_parent or cmd
+ view.index = len(self.prefix) + 1
+ view.previous = 0
+ else:
+ to_call = cmd
+
+ try:
+ yield from to_call.reinvoke(self, call_hooks=call_hooks)
+ finally:
+ self.command = cmd
+ view.index = index
+ view.previous = previous
+ self.invoked_with = invoked_with
+ self.invoked_subcommand = invoked_subcommand
+ self.subcommand_passed = subcommand_passed
+
@property
def valid(self):
"""Checks if the invocation context is valid to be invoked with."""
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index 29a0433c..2662a31b 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -283,6 +283,24 @@ class Command:
return ' '.join(reversed(entries))
@property
+ def root_parent(self):
+ """Retrieves the root parent of this command.
+
+ If the command has no parents then it returns ``None``.
+
+ For example in commands ``?a b c test``, the root parent is
+ ``a``.
+ """
+ entries = []
+ command = self
+ while command.parent is not None:
+ command = command.parent
+ entries.append(command)
+ entries.append(None)
+ entries.reverse()
+ return entries[-1]
+
+ @property
def qualified_name(self):
"""Retrieves the fully qualified command name.
@@ -350,7 +368,6 @@ class Command:
if not view.eof:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
-
@asyncio.coroutine
def _verify_checks(self, ctx):
if not self.enabled:
@@ -407,7 +424,6 @@ class Command:
def prepare(self, ctx):
ctx.command = self
yield from self._verify_checks(ctx)
- yield from self._parse_arguments(ctx)
if self._buckets.valid:
bucket = self._buckets.get_bucket(ctx)
@@ -415,6 +431,7 @@ class Command:
if retry_after:
raise CommandOnCooldown(bucket, retry_after)
+ yield from self._parse_arguments(ctx)
yield from self.call_before_hooks(ctx)
def reset_cooldown(self, ctx):
@@ -440,6 +457,24 @@ class Command:
injected = hooked_wrapped_callback(self, ctx, self.callback)
yield from injected(*ctx.args, **ctx.kwargs)
+ @asyncio.coroutine
+ def reinvoke(self, ctx, *, call_hooks=False):
+ ctx.command = self
+ yield from self._parse_arguments(ctx)
+
+ if call_hooks:
+ yield from self.call_before_hooks(ctx)
+
+ ctx.invoked_subcommand = None
+ try:
+ yield from self.callback(*ctx.args, **ctx.kwargs)
+ except:
+ ctx.command_failed = True
+ raise
+ finally:
+ if call_hooks:
+ yield from self.call_after_hooks(ctx)
+
def error(self, coro):
"""A decorator that registers a coroutine as a local error handler.
@@ -821,6 +856,44 @@ class Group(GroupMixin, Command):
view.previous = previous
yield from super().invoke(ctx)
+ @asyncio.coroutine
+ def reinvoke(self, ctx, *, call_hooks=False):
+ early_invoke = not self.invoke_without_command
+ if early_invoke:
+ ctx.command = self
+ yield from self._parse_arguments(ctx)
+
+ if call_hooks:
+ yield from self.call_before_hooks(ctx)
+
+ view = ctx.view
+ previous = view.index
+ view.skip_ws()
+ trigger = view.get_word()
+
+ if trigger:
+ ctx.subcommand_passed = trigger
+ ctx.invoked_subcommand = self.all_commands.get(trigger, None)
+
+ if early_invoke:
+ try:
+ yield from self.callback(*ctx.args, **ctx.kwargs)
+ except:
+ ctx.command_failed = True
+ raise
+ finally:
+ if call_hooks:
+ yield from self.call_after_hooks(ctx)
+
+ if trigger and ctx.invoked_subcommand:
+ ctx.invoked_with = trigger
+ yield from ctx.invoked_subcommand.reinvoke(ctx, call_hooks=call_hooks)
+ elif not early_invoke:
+ # undo the trigger parsing
+ view.index = previous
+ view.previous = previous
+ yield from super().reinvoke(ctx, call_hooks=call_hooks)
+
# Decorators
def command(name=None, cls=None, **attrs):