diff options
Diffstat (limited to 'discord/ext/commands')
| -rw-r--r-- | discord/ext/commands/converter.py | 25 | ||||
| -rw-r--r-- | discord/ext/commands/core.py | 56 |
2 files changed, 79 insertions, 2 deletions
diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 880876a4..461abee7 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -34,7 +34,7 @@ __all__ = ['Converter', 'MemberConverter', 'UserConverter', 'TextChannelConverter', 'InviteConverter', 'RoleConverter', 'GameConverter', 'ColourConverter', 'VoiceChannelConverter', 'EmojiConverter', 'PartialEmojiConverter', 'CategoryChannelConverter', - 'IDConverter', 'clean_content'] + 'IDConverter', 'clean_content', 'Greedy'] def _get_from_guilds(bot, getter, argument): result = None @@ -483,3 +483,26 @@ class clean_content(Converter): # Completely ensure no mentions escape: return re.sub(r'@(everyone|here|[!&]?[0-9]{17,21})', '@\u200b\\1', result) + +class _Greedy: + __slots__ = ('converter',) + + def __init__(self, *, converter=None): + self.converter = converter + + def __getitem__(self, params): + if not isinstance(params, tuple): + params = (params,) + if len(params) != 1: + raise TypeError('Greedy[...] only takes a single argument') + converter = params[0] + + if not inspect.isclass(converter): + raise TypeError('Greedy[...] expects a type.') + + if converter is str or converter is type(None) or converter is _Greedy: + raise TypeError('Greedy[%s] is invalid.' % converter.__name__) + + return self.__class__(converter=converter) + +Greedy = _Greedy() diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index c64aeb9a..0aa79645 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -199,7 +199,11 @@ class Command: # be replaced with the real value for the converters to work later on for key, value in self.params.items(): if isinstance(value.annotation, str): - self.params[key] = value.replace(annotation=eval(value.annotation, function.__globals__)) + self.params[key] = value = value.replace(annotation=eval(value.annotation, function.__globals__)) + + # fail early for when someone passes an unparameterized Greedy type + if value.annotation is converters.Greedy: + raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') async def dispatch_error(self, ctx, error): ctx.command_failed = True @@ -318,6 +322,19 @@ class Command: view = ctx.view view.skip_ws() + # The greedy converter is simple -- it keeps going until it fails in which case, + # it undos the view ready for the next parameter to use instead + if type(converter) is converters._Greedy: + if param.kind == param.POSITIONAL_OR_KEYWORD: + return await self._transform_greedy_pos(ctx, param, required, converter.converter) + elif param.kind == param.VAR_POSITIONAL: + return await self._transform_greedy_var_pos(ctx, param, converter.converter) + else: + # if we're here, then it's a KEYWORD_ONLY param type + # since this is mostly useless, we'll helpfully transform Greedy[X] + # into just X and do the parsing that way. + converter = converter.converter + if view.eof: if param.kind == param.VAR_POSITIONAL: raise RuntimeError() # break the loop @@ -334,6 +351,43 @@ class Command: return (await self.do_conversion(ctx, converter, argument, param)) + async def _transform_greedy_pos(self, ctx, param, required, converter): + view = ctx.view + result = [] + while not view.eof: + # for use with a manual undo + previous = view.index + + # parsing errors get propagated + view.skip_ws() + argument = quoted_word(view) + try: + value = await self.do_conversion(ctx, converter, argument, param) + except CommandError as e: + if not result: + if required: + raise + else: + view.index = previous + return param.default + view.index = previous + break + else: + result.append(value) + return result + + async def _transform_greedy_var_pos(self, ctx, param, converter): + view = ctx.view + previous = view.index + argument = quoted_word(view) + try: + value = await self.do_conversion(ctx, converter, argument, param) + except CommandError: + view.index = previous + raise RuntimeError() from None # break loop + else: + return value + @property def clean_params(self): """Retrieves the parameter OrderedDict without the context or self parameters. |