aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
authorSigmath Bits <[email protected]>2021-04-10 18:50:59 +1200
committerGitHub <[email protected]>2021-04-10 02:50:59 -0400
commit68aef92b377f61ed465660646659d4ba0100c314 (patch)
treec5e2bfd811c9ceac60ed5ce9422012f94e482f25 /discord
parentmake examples on_ready consistent (diff)
downloaddiscord.py-68aef92b377f61ed465660646659d4ba0100c314.tar.xz
discord.py-68aef92b377f61ed465660646659d4ba0100c314.zip
[commands]Add typing.Literal converter
Diffstat (limited to 'discord')
-rw-r--r--discord/ext/commands/core.py95
-rw-r--r--discord/ext/commands/errors.py34
2 files changed, 99 insertions, 30 deletions
diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py
index a570ee48..cce6f30c 100644
--- a/discord/ext/commands/core.py
+++ b/discord/ext/commands/core.py
@@ -489,31 +489,52 @@ class Command(_BaseCommand):
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def do_conversion(self, ctx, converter, argument, param):
- try:
- origin = converter.__origin__
- except AttributeError:
- pass
- else:
- if origin is typing.Union:
- errors = []
- _NoneType = type(None)
- for conv in converter.__args__:
- # if we got to this part in the code, then the previous conversions have failed
- # so we should just undo the view, return the default, and allow parsing to continue
- # with the other parameters
- if conv is _NoneType and param.kind != param.VAR_POSITIONAL:
- ctx.view.undo()
- return None if param.default is param.empty else param.default
-
+ origin = typing.get_origin(converter)
+
+ if origin is typing.Union:
+ errors = []
+ _NoneType = type(None)
+ for conv in typing.get_args(converter):
+ # if we got to this part in the code, then the previous conversions have failed
+ # so we should just undo the view, return the default, and allow parsing to continue
+ # with the other parameters
+ if conv is _NoneType and param.kind != param.VAR_POSITIONAL:
+ ctx.view.undo()
+ return None if param.default is param.empty else param.default
+
+ try:
+ value = await self.do_conversion(ctx, conv, argument, param)
+ except CommandError as exc:
+ errors.append(exc)
+ else:
+ return value
+
+ # if we're here, then we failed all the converters
+ raise BadUnionArgument(param, typing.get_args(converter), errors)
+
+ if origin is typing.Literal:
+ errors = []
+ conversions = {}
+ literal_args = tuple(self._flattened_typing_literal_args(converter))
+ for literal in literal_args:
+ literal_type = type(literal)
+ try:
+ value = conversions[literal_type]
+ except KeyError:
try:
- value = await self._actual_conversion(ctx, conv, argument, param)
+ value = await self._actual_conversion(ctx, literal_type, argument, param)
except CommandError as exc:
errors.append(exc)
+ conversions[literal_type] = object()
+ continue
else:
- return value
+ conversions[literal_type] = value
+
+ if value == literal:
+ return value
- # if we're here, then we failed all the converters
- raise BadUnionArgument(param, converter.__args__, errors)
+ # if we're here, then we failed to match all the literals
+ raise BadLiteralArgument(param, literal_args, errors)
return await self._actual_conversion(ctx, converter, argument, param)
@@ -995,15 +1016,14 @@ class Command(_BaseCommand):
return ''
def _is_typing_optional(self, annotation):
- try:
- origin = annotation.__origin__
- except AttributeError:
- return False
+ return typing.get_origin(annotation) is typing.Union and typing.get_args(annotation)[-1] is type(None)
- if origin is not typing.Union:
- return False
-
- return annotation.__args__[-1] is type(None)
+ def _flattened_typing_literal_args(self, annotation):
+ for literal in typing.get_args(annotation):
+ if typing.get_origin(literal) is typing.Literal:
+ yield from self._flattened_typing_literal_args(literal)
+ else:
+ yield literal
@property
def signature(self):
@@ -1011,7 +1031,6 @@ class Command(_BaseCommand):
if self.usage is not None:
return self.usage
-
params = self.clean_params
if not params:
return ''
@@ -1019,6 +1038,22 @@ class Command(_BaseCommand):
result = []
for name, param in params.items():
greedy = isinstance(param.annotation, converters._Greedy)
+ optional = False # postpone evaluation of if it's an optional argument
+
+ # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
+ # parameter signature is a literal list of it's values
+ annotation = param.annotation.converter if greedy else param.annotation
+ origin = typing.get_origin(annotation)
+ if not greedy and origin is typing.Union:
+ union_args = typing.get_args(annotation)
+ optional = union_args[-1] is type(None)
+ if optional:
+ annotation = union_args[0]
+ origin = typing.get_origin(annotation)
+
+ if origin is typing.Literal:
+ name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v)
+ for v in self._flattened_typing_literal_args(annotation))
if param.default is not param.empty:
# We don't want None or '' to trigger the [name=value] case and instead it should
@@ -1038,7 +1073,7 @@ class Command(_BaseCommand):
result.append(f'[{name}...]')
elif greedy:
result.append(f'[{name}]...')
- elif self._is_typing_optional(param.annotation):
+ elif optional:
result.append(f'[{name}]')
else:
result.append(f'<{name}>')
diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py
index f8a2724d..f7e745e0 100644
--- a/discord/ext/commands/errors.py
+++ b/discord/ext/commands/errors.py
@@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""
from discord.errors import ClientException, DiscordException
+import typing
__all__ = (
@@ -62,6 +63,7 @@ __all__ = (
'NSFWChannelRequired',
'ConversionError',
'BadUnionArgument',
+ 'BadLiteralArgument',
'ArgumentParsingError',
'UnexpectedQuoteError',
'InvalidEndOfQuotedStringError',
@@ -644,6 +646,8 @@ class BadUnionArgument(UserInputError):
try:
return x.__name__
except AttributeError:
+ if typing.get_origin(x) is not None:
+ return repr(x)
return x.__class__.__name__
to_string = [_get_name(x) for x in converters]
@@ -654,6 +658,36 @@ class BadUnionArgument(UserInputError):
super().__init__(f'Could not convert "{param.name}" into {fmt}.')
+class BadLiteralArgument(UserInputError):
+ """Exception raised when a :data:`typing.Literal` converter fails for all
+ its associated values.
+
+ This inherits from :exc:`UserInputError`
+
+ .. versionadded:: 2.0
+
+ Attributes
+ -----------
+ param: :class:`inspect.Parameter`
+ The parameter that failed being converted.
+ literals: Tuple[Any, ...]
+ A tuple of values compared against in conversion, in order of failure.
+ errors: List[:class:`CommandError`]
+ A list of errors that were caught from failing the conversion.
+ """
+ def __init__(self, param, literals, errors):
+ self.param = param
+ self.literals = literals
+ self.errors = errors
+
+ to_string = [repr(l) for l in literals]
+ if len(to_string) > 2:
+ fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
+ else:
+ fmt = ' or '.join(to_string)
+
+ super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
+
class ArgumentParsingError(UserInputError):
"""An exception raised when the parser fails to parse a user's input.